Skip to content

Commit

Permalink
Rollup merge of #73949 - wesleywiser:simplify_try_fixes, r=oli-obk
Browse files Browse the repository at this point in the history
[mir-opt] Fix mis-optimization and other issues with the SimplifyArmIdentity pass

This does not yet attempt re-enabling the pass, but it does resolve a number of issues with the pass.

r? @oli-obk

I believe this closes #73223.
  • Loading branch information
Manishearth committed Jul 4, 2020
2 parents 9d0ca38 + e16d6a6 commit 60cad20
Show file tree
Hide file tree
Showing 12 changed files with 1,411 additions and 35 deletions.
12 changes: 12 additions & 0 deletions src/librustc_middle/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,18 @@ impl<'tcx> Body<'tcx> {
(&mut self.basic_blocks, &mut self.local_decls)
}

#[inline]
pub fn basic_blocks_local_decls_mut_and_var_debug_info(
&mut self,
) -> (
&mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
&mut LocalDecls<'tcx>,
&mut Vec<VarDebugInfo<'tcx>>,
) {
self.predecessor_cache.invalidate();
(&mut self.basic_blocks, &mut self.local_decls, &mut self.var_debug_info)
}

/// Returns `true` if a cycle exists in the control-flow graph that is reachable from the
/// `START_BLOCK`.
pub fn is_cfg_cyclic(&self) -> bool {
Expand Down
115 changes: 106 additions & 9 deletions src/librustc_mir/transform/simplify_try.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@

use crate::transform::{simplify, MirPass, MirSource};
use itertools::Itertools as _;
use rustc_index::vec::IndexVec;
use rustc_index::{bit_set::BitSet, vec::IndexVec};
use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::{Ty, TyCtxt};
use rustc_middle::ty::{List, Ty, TyCtxt};
use rustc_target::abi::VariantIdx;
use std::iter::{Enumerate, Peekable};
use std::slice::Iter;
Expand Down Expand Up @@ -73,9 +74,20 @@ struct ArmIdentityInfo<'tcx> {

/// The statements that should be removed (turned into nops)
stmts_to_remove: Vec<usize>,

/// Indices of debug variables that need to be adjusted to point to
// `{local_0}.{dbg_projection}`.
dbg_info_to_adjust: Vec<usize>,

/// The projection used to rewrite debug info.
dbg_projection: &'tcx List<PlaceElem<'tcx>>,
}

fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmIdentityInfo<'tcx>> {
fn get_arm_identity_info<'a, 'tcx>(
stmts: &'a [Statement<'tcx>],
locals_count: usize,
debug_info: &'a [VarDebugInfo<'tcx>],
) -> Option<ArmIdentityInfo<'tcx>> {
// This can't possibly match unless there are at least 3 statements in the block
// so fail fast on tiny blocks.
if stmts.len() < 3 {
Expand Down Expand Up @@ -187,7 +199,7 @@ fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmId
try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);

let (get_variant_field_stmt, stmt) = stmt_iter.next()?;
let (local_tmp_s0, local_1, vf_s0) = match_get_variant_field(stmt)?;
let (local_tmp_s0, local_1, vf_s0, dbg_projection) = match_get_variant_field(stmt)?;

try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);

Expand Down Expand Up @@ -228,6 +240,19 @@ fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmId
let stmt_to_overwrite =
nop_stmts.iter().find(|stmt_idx| live_idx < **stmt_idx && **stmt_idx < dead_idx);

let mut tmp_assigned_vars = BitSet::new_empty(locals_count);
for (l, r) in &tmp_assigns {
tmp_assigned_vars.insert(*l);
tmp_assigned_vars.insert(*r);
}

let mut dbg_info_to_adjust = Vec::new();
for (i, var_info) in debug_info.iter().enumerate() {
if tmp_assigned_vars.contains(var_info.place.local) {
dbg_info_to_adjust.push(i);
}
}

Some(ArmIdentityInfo {
local_temp_0: local_tmp_s0,
local_1,
Expand All @@ -243,12 +268,16 @@ fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmId
source_info: discr_stmt_source_info,
storage_stmts,
stmts_to_remove: nop_stmts,
dbg_info_to_adjust,
dbg_projection,
})
}

fn optimization_applies<'tcx>(
opt_info: &ArmIdentityInfo<'tcx>,
local_decls: &IndexVec<Local, LocalDecl<'tcx>>,
local_uses: &IndexVec<Local, usize>,
var_debug_info: &[VarDebugInfo<'tcx>],
) -> bool {
trace!("testing if optimization applies...");

Expand All @@ -273,6 +302,7 @@ fn optimization_applies<'tcx>(
// Verify the assigment chain consists of the form b = a; c = b; d = c; etc...
if opt_info.field_tmp_assignments.is_empty() {
trace!("NO: no assignments found");
return false;
}
let mut last_assigned_to = opt_info.field_tmp_assignments[0].1;
let source_local = last_assigned_to;
Expand All @@ -285,6 +315,35 @@ fn optimization_applies<'tcx>(
last_assigned_to = *l;
}

// Check that the first and last used locals are only used twice
// since they are of the form:
//
// ```
// _first = ((_x as Variant).n: ty);
// _n = _first;
// ...
// ((_y as Variant).n: ty) = _n;
// discriminant(_y) = z;
// ```
for (l, r) in &opt_info.field_tmp_assignments {
if local_uses[*l] != 2 {
warn!("NO: FAILED assignment chain local {:?} was used more than twice", l);
return false;
} else if local_uses[*r] != 2 {
warn!("NO: FAILED assignment chain local {:?} was used more than twice", r);
return false;
}
}

// Check that debug info only points to full Locals and not projections.
for dbg_idx in &opt_info.dbg_info_to_adjust {
let dbg_info = &var_debug_info[*dbg_idx];
if !dbg_info.place.projection.is_empty() {
trace!("NO: debug info for {:?} had a projection {:?}", dbg_info.name, dbg_info.place);
return false;
}
}

if source_local != opt_info.local_temp_0 {
trace!(
"NO: start of assignment chain does not match enum variant temp: {:?} != {:?}",
Expand Down Expand Up @@ -312,11 +371,15 @@ impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
}

trace!("running SimplifyArmIdentity on {:?}", source);
let (basic_blocks, local_decls) = body.basic_blocks_and_local_decls_mut();
let local_uses = LocalUseCounter::get_local_uses(body);
let (basic_blocks, local_decls, debug_info) =
body.basic_blocks_local_decls_mut_and_var_debug_info();
for bb in basic_blocks {
if let Some(opt_info) = get_arm_identity_info(&bb.statements) {
if let Some(opt_info) =
get_arm_identity_info(&bb.statements, local_decls.len(), debug_info)
{
trace!("got opt_info = {:#?}", opt_info);
if !optimization_applies(&opt_info, local_decls) {
if !optimization_applies(&opt_info, local_decls, &local_uses, &debug_info) {
debug!("optimization skipped for {:?}", source);
continue;
}
Expand Down Expand Up @@ -352,23 +415,57 @@ impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {

bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop);

// Fix the debug info to point to the right local
for dbg_index in opt_info.dbg_info_to_adjust {
let dbg_info = &mut debug_info[dbg_index];
assert!(dbg_info.place.projection.is_empty());
dbg_info.place.local = opt_info.local_0;
dbg_info.place.projection = opt_info.dbg_projection;
}

trace!("block is now {:?}", bb.statements);
}
}
}
}

struct LocalUseCounter {
local_uses: IndexVec<Local, usize>,
}

impl LocalUseCounter {
fn get_local_uses<'tcx>(body: &Body<'tcx>) -> IndexVec<Local, usize> {
let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) };
counter.visit_body(body);
counter.local_uses
}
}

impl<'tcx> Visitor<'tcx> for LocalUseCounter {
fn visit_local(&mut self, local: &Local, context: PlaceContext, _location: Location) {
if context.is_storage_marker()
|| context == PlaceContext::NonUse(NonUseContext::VarDebugInfo)
{
return;
}

self.local_uses[*local] += 1;
}
}

/// Match on:
/// ```rust
/// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
/// ```
fn match_get_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> {
fn match_get_variant_field<'tcx>(
stmt: &Statement<'tcx>,
) -> Option<(Local, Local, VarField<'tcx>, &'tcx List<PlaceElem<'tcx>>)> {
match &stmt.kind {
StatementKind::Assign(box (place_into, rvalue_from)) => match rvalue_from {
Rvalue::Use(Operand::Copy(pf) | Operand::Move(pf)) => {
let local_into = place_into.as_local()?;
let (local_from, vf) = match_variant_field_place(*pf)?;
Some((local_into, local_from, vf))
Some((local_into, local_from, vf, pf.projection))
}
_ => None,
},
Expand Down
13 changes: 13 additions & 0 deletions src/test/mir-opt/issue-73223.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
fn main() {
let split = match Some(1) {
Some(v) => v,
None => return,
};

let _prev = Some(split);
assert_eq!(split, 1);
}

// EMIT_MIR_FOR_EACH_BIT_WIDTH
// EMIT_MIR rustc.main.SimplifyArmIdentity.diff
// EMIT_MIR rustc.main.PreCodegen.diff
Loading

0 comments on commit 60cad20

Please sign in to comment.