Skip to content

Commit

Permalink
don't return errors on too large requests on a reversed bitreader (#58)
Browse files Browse the repository at this point in the history
* don't return errors on too large requests on a reversed bitreader

* introduce checks for maximum symbol in the FSE table decoding
  • Loading branch information
KillingSpark committed May 30, 2024
1 parent 53e7b1a commit 944b391
Show file tree
Hide file tree
Showing 12 changed files with 101 additions and 101 deletions.
5 changes: 4 additions & 1 deletion Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ This document records the changes made between versions, starting with version 0
* The FrameDecoder is now Send + Sync (RingBuffer impls these traits now)

# After 0.6.0
* Small fix in the zstd binary: progress tracking was slightly off for skippable frames, resulting in an error only when the last frame in a file was skippable
* Small fix in the zstd binary: progress tracking was slightly off for skippable frames, resulting in an error only when the last frame in a file was skippable
* Small performance improvement by reorganizing code with `#[cold]` annotations
* Documentation for `StreamDecoder` mentioning the limitations around multiple frames (https://github.com/Sorseg)
* Documentation around skippable frames (https://github.com/Sorseg)
4 changes: 2 additions & 2 deletions benches/reversedbitreader_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use ruzstd::decoding::bit_reader_reverse::BitReaderReversed;
fn do_all_accesses(br: &mut BitReaderReversed, accesses: &[u8]) -> u64 {
let mut sum = 0;
for x in accesses {
sum += br.get_bits(*x).unwrap();
sum += br.get_bits(*x);
}
let _ = black_box(br);
sum
Expand All @@ -24,7 +24,7 @@ fn criterion_benchmark(c: &mut Criterion) {
let mut br = BitReaderReversed::new(&rand_vec);
while br.bits_remaining() > 0 {
let x = rng.gen_range(0..20);
br.get_bits(x).unwrap();
br.get_bits(x);
access_vec.push(x);
}

Expand Down
Binary file not shown.
2 changes: 1 addition & 1 deletion fuzz/fuzz_targets/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ extern crate ruzstd;
use ruzstd::frame_decoder;

fuzz_target!(|data: &[u8]| {
let mut content = data.clone();
let mut content = data;
let mut frame_dec = frame_decoder::FrameDecoder::new();

match frame_dec.reset(&mut content){
Expand Down
4 changes: 4 additions & 0 deletions src/blocks/sequence_section.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
//! Utilities and representations for the second half of a block, the sequence section.
//! This section copies literals from the literals section into the decompressed output.

/// Largest literal length code the sequence decoder accepts; the literal length
/// lookup covers codes up to and including this value.
pub(crate) const MAX_LITERAL_LENGTH_CODE: u8 = 35;
/// Largest match length code the sequence decoder accepts; the match length
/// lookup covers codes up to and including this value.
pub(crate) const MAX_MATCH_LENGTH_CODE: u8 = 52;
/// Largest offset code the sequence decoder accepts; larger codes are reported
/// as an unsupported offset.
pub(crate) const MAX_OFFSET_CODE: u8 = 31;

pub struct SequencesHeader {
pub num_sequences: u32,
pub modes: Option<CompressionModes>,
Expand Down
53 changes: 18 additions & 35 deletions src/decoding/bit_reader_reverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,40 +111,34 @@ impl<'s> BitReaderReversed<'s> {
/// Read `n` number of bits from the source. Requesting more bits than remain
/// for reading is not an error: the bits that could not be read are returned as zeros.
#[inline(always)]
pub fn get_bits(&mut self, n: u8) -> Result<u64, GetBitsError> {
pub fn get_bits(&mut self, n: u8) -> u64 {
if n == 0 {
return Ok(0);
return 0;
}
if self.bits_in_container >= n {
return Ok(self.get_bits_unchecked(n));
return self.get_bits_unchecked(n);
}

self.get_bits_cold(n)
}

#[cold]
fn get_bits_cold(&mut self, n: u8) -> Result<u64, GetBitsError> {
if n > 56 {
return Err(GetBitsError::TooManyBits {
num_requested_bits: usize::from(n),
limit: 56,
});
}

fn get_bits_cold(&mut self, n: u8) -> u64 {
let n = u8::min(n, 56);
let signed_n = n as isize;

if self.bits_remaining() <= 0 {
self.idx -= signed_n;
return Ok(0);
return 0;
}

if self.bits_remaining() < signed_n {
let emulated_read_shift = signed_n - self.bits_remaining();
let v = self.get_bits(self.bits_remaining() as u8)?;
let v = self.get_bits(self.bits_remaining() as u8);
debug_assert!(self.idx == 0);
let value = v << emulated_read_shift;
let value = v.wrapping_shl(emulated_read_shift as u32);
self.idx -= emulated_read_shift;
return Ok(value);
return value;
}

while (self.bits_in_container < n) && self.idx > 0 {
Expand All @@ -155,23 +149,18 @@ impl<'s> BitReaderReversed<'s> {

//if we reach this point there are enough bits in the container

Ok(self.get_bits_unchecked(n))
self.get_bits_unchecked(n)
}

#[inline(always)]
pub fn get_bits_triple(
&mut self,
n1: u8,
n2: u8,
n3: u8,
) -> Result<(u64, u64, u64), GetBitsError> {
pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
let sum = n1 as usize + n2 as usize + n3 as usize;
if sum == 0 {
return Ok((0, 0, 0));
return (0, 0, 0);
}
if sum > 56 {
// try and get the values separately
return Ok((self.get_bits(n1)?, self.get_bits(n2)?, self.get_bits(n3)?));
return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
}
let sum = sum as u8;

Expand All @@ -192,29 +181,23 @@ impl<'s> BitReaderReversed<'s> {
self.get_bits_unchecked(n3)
};

return Ok((v1, v2, v3));
return (v1, v2, v3);
}

self.get_bits_triple_cold(n1, n2, n3, sum)
}

#[cold]
fn get_bits_triple_cold(
&mut self,
n1: u8,
n2: u8,
n3: u8,
sum: u8,
) -> Result<(u64, u64, u64), GetBitsError> {
fn get_bits_triple_cold(&mut self, n1: u8, n2: u8, n3: u8, sum: u8) -> (u64, u64, u64) {
let sum_signed = sum as isize;

if self.bits_remaining() <= 0 {
self.idx -= sum_signed;
return Ok((0, 0, 0));
return (0, 0, 0);
}

if self.bits_remaining() < sum_signed {
return Ok((self.get_bits(n1)?, self.get_bits(n2)?, self.get_bits(n3)?));
return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3));
}

while (self.bits_in_container < sum) && self.idx > 0 {
Expand All @@ -241,7 +224,7 @@ impl<'s> BitReaderReversed<'s> {
self.get_bits_unchecked(n3)
};

Ok((v1, v2, v3))
(v1, v2, v3)
}

#[inline(always)]
Expand Down
12 changes: 6 additions & 6 deletions src/decoding/literals_section_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ fn decompress_literals(
//skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
let mut skipped_bits = 0;
loop {
let val = br.get_bits(1)?;
let val = br.get_bits(1);
skipped_bits += 1;
if val == 1 || skipped_bits > 8 {
break;
Expand All @@ -208,11 +208,11 @@ fn decompress_literals(
//if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
}
decoder.init_state(&mut br)?;
decoder.init_state(&mut br);

while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
target.push(decoder.decode_symbol());
decoder.next_state(&mut br)?;
decoder.next_state(&mut br);
}
if br.bits_remaining() != -(scratch.table.max_num_bits as isize) {
return Err(DecompressLiteralsError::BitstreamReadMismatch {
Expand All @@ -230,7 +230,7 @@ fn decompress_literals(
let mut br = BitReaderReversed::new(source);
let mut skipped_bits = 0;
loop {
let val = br.get_bits(1)?;
let val = br.get_bits(1);
skipped_bits += 1;
if val == 1 || skipped_bits > 8 {
break;
Expand All @@ -240,10 +240,10 @@ fn decompress_literals(
//if more than 7 bits are 0, this is not the correct end of the bitstream. Either a bug or corrupted data
return Err(DecompressLiteralsError::ExtraPadding { skipped_bits });
}
decoder.init_state(&mut br)?;
decoder.init_state(&mut br);
while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
target.push(decoder.decode_symbol());
decoder.next_state(&mut br)?;
decoder.next_state(&mut br);
}
bytes_read += source.len() as u32;
}
Expand Down
16 changes: 10 additions & 6 deletions src/decoding/scratch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ use crate::fse::FSETable;
use crate::huff0::HuffmanTable;
use alloc::vec::Vec;

use crate::blocks::sequence_section::{
MAX_LITERAL_LENGTH_CODE, MAX_MATCH_LENGTH_CODE, MAX_OFFSET_CODE,
};

/// A block level decoding buffer.
pub struct DecoderScratch {
/// The decoder used for Huffman blocks.
Expand All @@ -29,11 +33,11 @@ impl DecoderScratch {
table: HuffmanTable::new(),
},
fse: FSEScratch {
offsets: FSETable::new(),
offsets: FSETable::new(MAX_OFFSET_CODE),
of_rle: None,
literal_lengths: FSETable::new(),
literal_lengths: FSETable::new(MAX_LITERAL_LENGTH_CODE),
ll_rle: None,
match_lengths: FSETable::new(),
match_lengths: FSETable::new(MAX_MATCH_LENGTH_CODE),
ml_rle: None,
},
buffer: DecodeBuffer::new(window_size),
Expand Down Expand Up @@ -104,11 +108,11 @@ pub struct FSEScratch {
impl FSEScratch {
pub fn new() -> FSEScratch {
FSEScratch {
offsets: FSETable::new(),
offsets: FSETable::new(MAX_OFFSET_CODE),
of_rle: None,
literal_lengths: FSETable::new(),
literal_lengths: FSETable::new(MAX_LITERAL_LENGTH_CODE),
ll_rle: None,
match_lengths: FSETable::new(),
match_lengths: FSETable::new(MAX_MATCH_LENGTH_CODE),
ml_rle: None,
}
}
Expand Down
40 changes: 26 additions & 14 deletions src/decoding/sequence_section_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use super::super::blocks::sequence_section::Sequence;
use super::super::blocks::sequence_section::SequencesHeader;
use super::bit_reader_reverse::{BitReaderReversed, GetBitsError};
use super::scratch::FSEScratch;
use crate::blocks::sequence_section::{
MAX_LITERAL_LENGTH_CODE, MAX_MATCH_LENGTH_CODE, MAX_OFFSET_CODE,
};
use crate::fse::{FSEDecoder, FSEDecoderError, FSETableError};
use alloc::vec::Vec;

Expand Down Expand Up @@ -116,7 +119,7 @@ pub fn decode_sequences(
//skip the 0 padding at the end of the last byte of the bit stream and throw away the first 1 found
let mut skipped_bits = 0;
loop {
let val = br.get_bits(1)?;
let val = br.get_bits(1);
skipped_bits += 1;
if val == 1 || skipped_bits > 8 {
break;
Expand Down Expand Up @@ -189,13 +192,13 @@ fn decode_sequences_with_rle(
//println!("ml Code: {}", ml_value);
//println!("");

if of_code >= 32 {
if of_code > MAX_OFFSET_CODE {
return Err(DecodeSequenceError::UnsupportedOffset {
offset_code: of_code,
});
}

let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits)?;
let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
let offset = obits as u32 + (1u32 << of_code);

if offset == 0 {
Expand All @@ -215,13 +218,13 @@ fn decode_sequences_with_rle(
// br.bits_remaining() / 8,
//);
if scratch.ll_rle.is_none() {
ll_dec.update_state(br)?;
ll_dec.update_state(br);
}
if scratch.ml_rle.is_none() {
ml_dec.update_state(br)?;
ml_dec.update_state(br);
}
if scratch.of_rle.is_none() {
of_dec.update_state(br)?;
of_dec.update_state(br);
}
}

Expand Down Expand Up @@ -264,13 +267,13 @@ fn decode_sequences_without_rle(
let (ll_value, ll_num_bits) = lookup_ll_code(ll_code);
let (ml_value, ml_num_bits) = lookup_ml_code(ml_code);

if of_code >= 32 {
if of_code > MAX_OFFSET_CODE {
return Err(DecodeSequenceError::UnsupportedOffset {
offset_code: of_code,
});
}

let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits)?;
let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
let offset = obits as u32 + (1u32 << of_code);

if offset == 0 {
Expand All @@ -289,9 +292,9 @@ fn decode_sequences_without_rle(
// br.bits_remaining(),
// br.bits_remaining() / 8,
//);
ll_dec.update_state(br)?;
ml_dec.update_state(br)?;
of_dec.update_state(br)?;
ll_dec.update_state(br);
ml_dec.update_state(br);
of_dec.update_state(br);
}

if br.bits_remaining() < 0 {
Expand Down Expand Up @@ -335,7 +338,7 @@ fn lookup_ll_code(code: u8) -> (u32, u8) {
33 => (16384, 14),
34 => (32768, 15),
35 => (65536, 16),
_ => (0, 255),
_ => unreachable!("Illegal literal length code was: {}", code),
}
}

Expand Down Expand Up @@ -367,7 +370,7 @@ fn lookup_ml_code(code: u8) -> (u32, u8) {
50 => (16387, 14),
51 => (32771, 15),
52 => (65539, 16),
_ => (0, 255),
_ => unreachable!("Illegal match length code was: {}", code),
}
}

Expand Down Expand Up @@ -405,6 +408,9 @@ fn maybe_update_fse_tables(
return Err(DecodeSequenceError::MissingByteForRleLlTable);
}
bytes_read += 1;
if source[0] > MAX_LITERAL_LENGTH_CODE {
return Err(DecodeSequenceError::MissingByteForRleMlTable);
}
scratch.ll_rle = Some(source[0]);
}
ModeType::Predefined => {
Expand Down Expand Up @@ -437,6 +443,9 @@ fn maybe_update_fse_tables(
return Err(DecodeSequenceError::MissingByteForRleOfTable);
}
bytes_read += 1;
if of_source[0] > MAX_OFFSET_CODE {
return Err(DecodeSequenceError::MissingByteForRleMlTable);
}
scratch.of_rle = Some(of_source[0]);
}
ModeType::Predefined => {
Expand Down Expand Up @@ -469,6 +478,9 @@ fn maybe_update_fse_tables(
return Err(DecodeSequenceError::MissingByteForRleMlTable);
}
bytes_read += 1;
if ml_source[0] > MAX_MATCH_LENGTH_CODE {
return Err(DecodeSequenceError::MissingByteForRleMlTable);
}
scratch.ml_rle = Some(ml_source[0]);
}
ModeType::Predefined => {
Expand Down Expand Up @@ -522,7 +534,7 @@ const OFFSET_DEFAULT_DISTRIBUTION: [i32; 29] = [

#[test]
fn test_ll_default() {
let mut table = crate::fse::FSETable::new();
let mut table = crate::fse::FSETable::new(MAX_LITERAL_LENGTH_CODE);
table
.build_from_probabilities(
LL_DEFAULT_ACC_LOG,
Expand Down
Loading

0 comments on commit 944b391

Please sign in to comment.