diff --git a/src/encoder.rs b/src/encoder.rs index b0d7268932..29171df6ba 100644 --- a/src/encoder.rs +++ b/src/encoder.rs @@ -1086,7 +1086,7 @@ pub fn encode_tx_block( alpha: i16, rdo_type: RDOType, need_recon_pixel: bool, -) -> (bool, i64) { +) -> (bool, ScaledDistortion) { let qidx = get_qidx(fi, ts, cw, tile_bo); assert_ne!(qidx, 0); // lossless is not yet supported let PlaneConfig { xdec, ydec, .. } = ts.input.planes[p].cfg; @@ -1129,7 +1129,7 @@ pub fn encode_tx_block( } if skip { - return (false, -1); + return (false, ScaledDistortion::zero()); } let mut residual_storage: AlignedArray<[i16; 64 * 64]> = @@ -1195,7 +1195,7 @@ pub fn encode_tx_block( fi.ac_delta_q[p], ); - let mut tx_dist: i64 = -1; + let mut tx_dist: u64 = 0; if !fi.use_tx_domain_distortion || need_recon_pixel { inverse_transform_add( @@ -1215,29 +1215,24 @@ pub fn encode_tx_block( let c = *a as i32 - *b as i32; (c * c) as u64 }) - .sum::() as i64; + .sum::(); let tx_dist_scale_bits = 2 * (3 - get_log_tx_scale(tx_size)); let tx_dist_scale_rounding_offset = 1 << (tx_dist_scale_bits - 1); tx_dist = (tx_dist + tx_dist_scale_rounding_offset) >> tx_dist_scale_bits; } if fi.config.train_rdo { - ts.rdo.add_rate( - fi.base_q_idx, - tx_size, - tx_dist as u64, - cost_coeffs as u64, - ); + ts.rdo.add_rate(fi.base_q_idx, tx_size, tx_dist, cost_coeffs as u64); } if rdo_type == RDOType::TxDistEstRate { // look up rate and distortion in table - let estimated_rate = estimate_rate(fi.base_q_idx, tx_size, tx_dist as u64); + let estimated_rate = estimate_rate(fi.base_q_idx, tx_size, tx_dist); w.add_bits_frac(estimated_rate as u32); } let bias = compute_distortion_bias(fi, ts.to_frame_block_offset(tile_bo), bsize); - (has_coeff, (tx_dist as f64 * bias * fi.dist_scale[p]) as i64) + (has_coeff, RawDistortion::new(tx_dist) * bias * fi.dist_scale[p]) } pub fn motion_compensate( @@ -1491,7 +1486,7 @@ pub fn encode_block_post_cdef( skip: bool, cfl: CFLParams, tx_size: TxSize, tx_type: TxType, mode_context: usize, mv_stack: &[CandidateMV], rdo_type: RDOType, need_recon_pixel: bool, record_stats: bool, -) -> (bool, i64) { +) -> (bool, ScaledDistortion) { let is_inter = !luma_mode.is_intra(); if is_inter { assert!(luma_mode == chroma_mode); @@ -1764,7 +1759,7 @@ pub fn write_tx_blocks( chroma_mode: PredictionMode, tile_bo: TileBlockOffset, bsize: BlockSize, tx_size: TxSize, tx_type: TxType, skip: bool, cfl: CFLParams, luma_only: bool, rdo_type: RDOType, need_recon_pixel: bool, -) -> (bool, i64) { +) -> (bool, ScaledDistortion) { let bw = bsize.width_mi() / tx_size.width_mi(); let bh = bsize.height_mi() / tx_size.height_mi(); let qidx = get_qidx(fi, ts, cw, tile_bo); @@ -1772,7 +1767,7 @@ pub fn write_tx_blocks( let PlaneConfig { xdec, ydec, .. } = ts.input.planes[1].cfg; let mut ac: AlignedArray<[i16; 32 * 32]> = AlignedArray::uninitialized(); let mut partition_has_coeff: bool = false; - let mut tx_dist: i64 = 0; + let mut tx_dist = ScaledDistortion::zero(); let do_chroma = has_chroma(tile_bo, bsize, xdec, ydec); ts.qc.update( @@ -1814,9 +1809,6 @@ pub fn write_tx_blocks( need_recon_pixel, ); partition_has_coeff |= has_coeff; - assert!( - !fi.use_tx_domain_distortion || need_recon_pixel || skip || dist >= 0 - ); tx_dist += dist; } } @@ -1900,12 +1892,6 @@ pub fn write_tx_blocks( need_recon_pixel, ); partition_has_coeff |= has_coeff; - assert!( - !fi.use_tx_domain_distortion - || need_recon_pixel - || skip - || dist >= 0 - ); tx_dist += dist; } } @@ -1923,9 +1909,9 @@ pub fn write_tx_tree( tile_bo: TileBlockOffset, bsize: BlockSize, tx_size: TxSize, tx_type: TxType, skip: bool, luma_only: bool, rdo_type: RDOType, need_recon_pixel: bool, -) -> (bool, i64) { +) -> (bool, ScaledDistortion) { if skip { - return (false, -1); + return (false, ScaledDistortion::zero()); } let bw = bsize.width_mi() / tx_size.width_mi(); let bh = bsize.height_mi() / tx_size.height_mi(); @@ -1933,7 +1919,6 @@ pub fn write_tx_tree( let PlaneConfig { xdec, ydec, .. } = ts.input.planes[1].cfg; let ac = &[0i16; 0]; - let mut tx_dist: i64 = 0; let mut partition_has_coeff: bool = false; ts.qc.update( @@ -1968,10 +1953,7 @@ pub fn write_tx_tree( need_recon_pixel, ); partition_has_coeff |= has_coeff; - assert!( - !fi.use_tx_domain_distortion || need_recon_pixel || skip || dist >= 0 - ); - tx_dist += dist; + let mut tx_dist = dist; if luma_only { return (partition_has_coeff, tx_dist); @@ -2044,12 +2026,6 @@ pub fn write_tx_tree( need_recon_pixel, ); partition_has_coeff |= has_coeff; - assert!( - !fi.use_tx_domain_distortion - || need_recon_pixel - || skip - || dist >= 0 - ); tx_dist += dist; } } @@ -2176,7 +2152,7 @@ fn encode_partition_bottomup( let w: &mut W = if cw.bc.cdef_coded { w_post_cdef } else { w_pre_cdef }; let tell = w.tell_frac(); cw.write_partition(w, tile_bo, PartitionType::PARTITION_NONE, bsize); - compute_rd_cost(fi, w.tell_frac() - tell, 0) + compute_rd_cost(fi, w.tell_frac() - tell, ScaledDistortion::zero()) } else { 0.0 }; @@ -2276,7 +2252,8 @@ fn encode_partition_bottomup( if cw.bc.cdef_coded { w_post_cdef } else { w_pre_cdef }; let tell = w.tell_frac(); cw.write_partition(w, tile_bo, partition, bsize); - rd_cost = compute_rd_cost(fi, w.tell_frac() - tell, 0); + rd_cost = + compute_rd_cost(fi, w.tell_frac() - tell, ScaledDistortion::zero()); } let four_partitions = [ diff --git a/src/rdo.rs b/src/rdo.rs index 26704812e4..2a30824226 100644 --- a/src/rdo.rs +++ b/src/rdo.rs @@ -197,7 +197,7 @@ pub fn estimate_rate(qindex: u8, ts: TxSize, fast_distortion: u64) -> u64 { #[allow(unused)] fn cdef_dist_wxh_8x8( src1: &PlaneRegion<'_, T>, src2: &PlaneRegion<'_, T>, bit_depth: usize, -) -> u64 { +) -> RawDistortion { debug_assert!(src1.plane_cfg.xdec == 0); debug_assert!(src1.plane_cfg.ydec == 0); debug_assert!(src2.plane_cfg.xdec == 0); @@ -228,14 +228,14 @@ fn cdef_dist_wxh_8x8( let ssim_boost = (4033_f64 / 16_384_f64) * (svar + dvar + (16_384 << (2 * coeff_shift)) as f64) / f64::sqrt((16_265_089u64 << (4 * coeff_shift)) as f64 + svar * dvar); - (sse * ssim_boost + 0.5_f64) as u64 + RawDistortion::new((sse * ssim_boost + 0.5_f64) as u64) } #[allow(unused)] fn cdef_dist_wxh f64>( src1: &PlaneRegion<'_, T>, src2: &PlaneRegion<'_, T>, w: usize, h: usize, bit_depth: usize, compute_bias: F, -) -> u64 { +) -> Distortion { assert!(w & 0x7 == 0); assert!(h & 0x7 == 0); debug_assert!(src1.plane_cfg.xdec == 0); @@ -243,7 +243,7 @@ fn cdef_dist_wxh f64>( debug_assert!(src2.plane_cfg.xdec == 0); debug_assert!(src2.plane_cfg.ydec == 0); - let mut sum: u64 = 0; + let mut sum = Distortion::zero(); for j in 0isize..h as isize / 8 { for i in 0isize..w as isize / 8 { let area = Area::StartingAt { x: i * 8, y: j * 8 }; @@ -255,8 +255,7 @@ fn cdef_dist_wxh f64>( // cdef is always called on non-subsampled planes, so BLOCK_8X8 is // correct here. - let bias = compute_bias(area, BlockSize::BLOCK_8X8); - sum += (value as f64 * bias) as u64; + sum += value * compute_bias(area, BlockSize::BLOCK_8X8); } } sum @@ -266,7 +265,7 @@ fn cdef_dist_wxh f64>( pub fn sse_wxh f64>( src1: &PlaneRegion<'_, T>, src2: &PlaneRegion<'_, T>, w: usize, h: usize, compute_bias: F, -) -> u64 { +) -> Distortion { assert!(w & (MI_SIZE - 1) == 0); assert!(h & (MI_SIZE - 1) == 0); @@ -278,7 +277,7 @@ pub fn sse_wxh f64>( let block_w = imp_block_w >> src1.plane_cfg.xdec; let block_h = imp_block_h >> src1.plane_cfg.ydec; - let mut sse: u64 = 0; + let mut sse = Distortion::zero(); for block_y in 0..h / block_h { for block_x in 0..w / block_w { let mut value = 0; @@ -308,7 +307,7 @@ pub fn sse_wxh f64>( }, imp_bsize, ); - sse += (value as f64 * bias) as u64; + sse += RawDistortion::new(value) * bias; } } sse @@ -318,7 +317,7 @@ pub fn sse_wxh f64>( fn compute_distortion( fi: &FrameInvariants, ts: &TileStateMut<'_, T>, bsize: BlockSize, is_chroma_block: bool, tile_bo: TileBlockOffset, luma_only: bool, -) -> u64 { +) -> ScaledDistortion { let area = Area::BlockStartingAt { bo: tile_bo.0 }; let input_region = ts.input_tile.planes[0].subregion(area); let rec_region = ts.rec.planes[0].subregion(area); @@ -352,9 +351,7 @@ fn compute_distortion( ) }, ), - }; - - distortion = (fi.dist_scale[0] * distortion as f64) as u64; + } * fi.dist_scale[0]; if !luma_only { let PlaneConfig { xdec, ydec, .. } = ts.input.planes[1].cfg; @@ -373,7 +370,7 @@ fn compute_distortion( for p in 1..3 { let input_region = ts.input_tile.planes[p].subregion(area); let rec_region = ts.rec.planes[p].subregion(area); - distortion += (sse_wxh( + distortion += sse_wxh( &input_region, &rec_region, w_uv, @@ -385,8 +382,7 @@ fn compute_distortion( bsize, ) }, - ) as f64 - * fi.dist_scale[p]) as u64; + ) * fi.dist_scale[p]; } }; } @@ -396,15 +392,15 @@ fn compute_distortion( // Compute the transform-domain distortion for an encode fn compute_tx_distortion( fi: &FrameInvariants, ts: &TileStateMut<'_, T>, bsize: BlockSize, - is_chroma_block: bool, tile_bo: TileBlockOffset, tx_dist: i64, skip: bool, - luma_only: bool, -) -> u64 { + is_chroma_block: bool, tile_bo: TileBlockOffset, tx_dist: ScaledDistortion, + skip: bool, luma_only: bool, +) -> ScaledDistortion { assert!(fi.config.tune == Tune::Psnr); let area = Area::BlockStartingAt { bo: tile_bo.0 }; let input_region = ts.input_tile.planes[0].subregion(area); let rec_region = ts.rec.planes[0].subregion(area); let mut distortion = if skip { - (sse_wxh( + sse_wxh( &input_region, &rec_region, bsize.width(), @@ -416,11 +412,9 @@ fn compute_tx_distortion( bsize, ) }, - ) as f64 - * fi.dist_scale[0]) as u64 + ) * fi.dist_scale[0] } else { - assert!(tx_dist >= 0); - tx_dist as u64 + tx_dist }; if !luma_only && skip { @@ -440,7 +434,7 @@ fn compute_tx_distortion( for p in 1..3 { let input_region = ts.input_tile.planes[p].subregion(area); let rec_region = ts.rec.planes[p].subregion(area); - distortion += (sse_wxh( + distortion += sse_wxh( &input_region, &rec_region, w_uv, @@ -452,8 +446,7 @@ fn compute_tx_distortion( bsize, ) }, - ) as f64 - * fi.dist_scale[p]) as u64; + ) * fi.dist_scale[p]; } } } @@ -493,11 +486,64 @@ pub fn compute_distortion_bias( bias } +#[repr(transparent)] +pub struct RawDistortion(u64); + +#[repr(transparent)] +pub struct Distortion(u64); + +#[repr(transparent)] +pub struct ScaledDistortion(u64); + +impl RawDistortion { + pub fn new(dist: u64) -> Self { + Self(dist) + } +} + +impl std::ops::Mul for RawDistortion { + type Output = Distortion; + fn mul(self, rhs: f64) -> Distortion { + Distortion((self.0 as f64 * rhs) as u64) + } +} + +impl Distortion { + pub fn zero() -> Self { + Self(0) + } +} + +impl std::ops::Mul for Distortion { + type Output = ScaledDistortion; + fn mul(self, rhs: f64) -> ScaledDistortion { + ScaledDistortion((self.0 as f64 * rhs) as u64) + } +} + +impl std::ops::AddAssign for Distortion { + fn add_assign(&mut self, other: Self) { + self.0 += other.0; + } +} + +impl ScaledDistortion { + pub fn zero() -> Self { + Self(0) + } +} + +impl std::ops::AddAssign for ScaledDistortion { + fn add_assign(&mut self, other: Self) { + self.0 += other.0; + } +} + pub fn compute_rd_cost( - fi: &FrameInvariants, rate: u32, distortion: u64, + fi: &FrameInvariants, rate: u32, distortion: ScaledDistortion, ) -> f64 { let rate_in_bits = (rate as f64) / ((1 << OD_BITRES) as f64); - distortion as f64 + fi.lambda * rate_in_bits + distortion.0 as f64 + fi.lambda * rate_in_bits } pub fn rdo_tx_size_type( @@ -712,6 +758,7 @@ fn luma_chroma_mode_rdo( } else { compute_distortion(fi, ts, bsize, is_chroma_block, tile_bo, false) }; + let is_zero_dist = distortion.0 == 0; let rd = compute_rd_cost(fi, rate, distortion); if rd < best.rd { //if rd < best.rd || luma_mode == PredictionMode::NEW_NEWMV { @@ -725,7 +772,7 @@ fn luma_chroma_mode_rdo( best.tx_size = tx_size; best.tx_type = tx_type; best.sidx = sidx; - zero_distortion = distortion == 0; + zero_distortion = is_zero_dist; } cw.rollback(cw_checkpoint); @@ -1263,6 +1310,7 @@ pub fn rdo_cfl_alpha( uv_tx_size.height(), |_, _| 1., // We're not doing RDO here. ) + .0 }; let mut best = (alpha_cost(0), 0); let mut count = 2; @@ -1548,7 +1596,11 @@ pub fn rdo_partition_decision( if cw.bc.cdef_coded { w_post_cdef } else { w_pre_cdef }; let tell = w.tell_frac(); cw.write_partition(w, tile_bo, partition, bsize); - cost = compute_rd_cost(fi, w.tell_frac() - tell, 0); + cost = compute_rd_cost( + fi, + w.tell_frac() - tell, + ScaledDistortion::zero(), + ); } let mut rd_cost_sum = 0.0; @@ -1627,14 +1679,14 @@ fn rdo_loop_plane_error( sbo: TileSuperBlockOffset, tile_sbo: TileSuperBlockOffset, sb_w: usize, sb_h: usize, fi: &FrameInvariants, ts: &TileStateMut<'_, T>, blocks: &TileBlocks<'_>, test: &Frame, pli: usize, -) -> u64 { +) -> ScaledDistortion { let sb_w_blocks = if fi.sequence.use_128x128_superblock { 16 } else { 8 } * sb_w; let sb_h_blocks = if fi.sequence.use_128x128_superblock { 16 } else { 8 } * sb_h; // Each direction block is 8x8 in y, potentially smaller if subsampled in chroma // accumulating in-frame and unpadded - let mut err: u64 = 0; + let mut err = Distortion::zero(); for by in 0..sb_h_blocks { for bx in 0..sb_w_blocks { let bo = tile_sbo.block_offset(bx << 1, by << 1); @@ -1659,16 +1711,15 @@ fn rdo_loop_plane_error( BlockSize::BLOCK_8X8, ); err += if pli == 0 { - (cdef_dist_wxh_8x8(&in_region, &test_region, fi.sequence.bit_depth) - as f64 - * bias) as u64 + cdef_dist_wxh_8x8(&in_region, &test_region, fi.sequence.bit_depth) + * bias } else { sse_wxh(&in_region, &test_region, 8 >> xdec, 8 >> ydec, |_, _| bias) }; } } } - (err as f64 * fi.dist_scale[pli]) as u64 + err * fi.dist_scale[pli] } // Passed in a superblock offset representing the upper left corner of @@ -1791,7 +1842,7 @@ pub fn rdo_loop_decision( let mut best_new_index = -1i8; for cdef_index in 0..(1 << fi.cdef_bits) { - let mut err = 0; + let mut err = ScaledDistortion::zero(); let mut rate = 0; cdef_filter_superblock( fi,