Skip to content

Commit

Permalink
Auto merge of rust-lang#14122 - lowr:patch/abort-macro-expansion-on-o…
Browse files Browse the repository at this point in the history
…verflow, r=Veykril

fix: Don't expand macros in the same expansion tree after overflow

This patch fixes 2 bugs:

- In `Expander::enter_expand_id()` (and in code paths it's called), we never check whether we've reached the recursion limit. Although it hasn't been reported as far as I'm aware, this may cause hangs or stack overflows if some malformed attribute macro is used on associated items.
- We keep expansion even when recursion limit is reached. Take the following for example:

  ```rust
  macro_rules! foo { () => {{ foo!(); foo!(); }} }
  fn main() { foo!(); }
  ```

  We keep expanding the first `foo!()` in each expansion and would reach the limit at some point, *after which* we would try expanding the second `foo!()` in each expansion until it hits the limit again. This will (by default) lead to ~2^128 expansions.

  This is essentially what's happening in rust-lang#14074. Unlike rustc, we don't just stop expanding macros when we fail as long as it produces some tokens so that we can provide completions and other services in incomplete macro calls.

This patch provides a method that takes care of recursion depths (`Expander::within_limit()`) and stops macro expansions in the whole macro expansion tree once it detects recursion depth overflow. To be honest, I'm not really satisfied with this fix because it can still be used in unintended ways to bypass overflow checks, and I'm still seeking ways such that misuses are caught by the compiler by leveraging types or something.

Fixes rust-lang#14074
  • Loading branch information
bors committed Feb 14, 2023
2 parents 3812951 + ae7e62c commit 2a57b01
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 49 deletions.
140 changes: 91 additions & 49 deletions crates/hir-def/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use la_arena::{Arena, ArenaMap};
use limit::Limit;
use profile::Count;
use rustc_hash::FxHashMap;
use syntax::{ast, AstPtr, SyntaxNodePtr};
use syntax::{ast, AstPtr, SyntaxNode, SyntaxNodePtr};

use crate::{
attr::Attrs,
Expand Down Expand Up @@ -51,7 +51,8 @@ pub struct Expander {
def_map: Arc<DefMap>,
current_file_id: HirFileId,
module: LocalModuleId,
recursion_limit: usize,
/// `recursion_depth == usize::MAX` indicates that the recursion limit has been reached.
recursion_depth: usize,
}

impl CfgExpander {
Expand Down Expand Up @@ -84,7 +85,7 @@ impl Expander {
def_map,
current_file_id,
module: module.local_id,
recursion_limit: 0,
recursion_depth: 0,
}
}

Expand All @@ -93,47 +94,52 @@ impl Expander {
db: &dyn DefDatabase,
macro_call: ast::MacroCall,
) -> Result<ExpandResult<Option<(Mark, T)>>, UnresolvedMacro> {
if self.recursion_limit(db).check(self.recursion_limit + 1).is_err() {
cov_mark::hit!(your_stack_belongs_to_me);
return Ok(ExpandResult::only_err(ExpandError::Other(
"reached recursion limit during macro expansion".into(),
)));
let mut unresolved_macro_err = None;

let result = self.within_limit(db, |this| {
let macro_call = InFile::new(this.current_file_id, &macro_call);

let resolver =
|path| this.resolve_path_as_macro(db, &path).map(|it| macro_id_to_def_id(db, it));

let mut err = None;
let call_id = match macro_call.as_call_id_with_errors(
db,
this.def_map.krate(),
resolver,
&mut |e| {
err.get_or_insert(e);
},
) {
Ok(call_id) => call_id,
Err(resolve_err) => {
unresolved_macro_err = Some(resolve_err);
return ExpandResult { value: None, err: None };
}
};
ExpandResult { value: call_id.ok(), err }
});

if let Some(err) = unresolved_macro_err {
Err(err)
} else {
Ok(result)
}

let macro_call = InFile::new(self.current_file_id, &macro_call);

let resolver =
|path| self.resolve_path_as_macro(db, &path).map(|it| macro_id_to_def_id(db, it));

let mut err = None;
let call_id =
macro_call.as_call_id_with_errors(db, self.def_map.krate(), resolver, &mut |e| {
err.get_or_insert(e);
})?;
let call_id = match call_id {
Ok(it) => it,
Err(_) => {
return Ok(ExpandResult { value: None, err });
}
};

Ok(self.enter_expand_inner(db, call_id, err))
}

pub fn enter_expand_id<T: ast::AstNode>(
&mut self,
db: &dyn DefDatabase,
call_id: MacroCallId,
) -> ExpandResult<Option<(Mark, T)>> {
self.enter_expand_inner(db, call_id, None)
self.within_limit(db, |_this| ExpandResult::ok(Some(call_id)))
}

fn enter_expand_inner<T: ast::AstNode>(
&mut self,
fn enter_expand_inner(
db: &dyn DefDatabase,
call_id: MacroCallId,
mut err: Option<ExpandError>,
) -> ExpandResult<Option<(Mark, T)>> {
) -> ExpandResult<Option<(HirFileId, SyntaxNode)>> {
if err.is_none() {
err = db.macro_expand_error(call_id);
}
Expand All @@ -154,29 +160,21 @@ impl Expander {
}
};

let node = match T::cast(raw_node) {
Some(it) => it,
None => {
// This can happen without being an error, so only forward previous errors.
return ExpandResult { value: None, err };
}
};

tracing::debug!("macro expansion {:#?}", node.syntax());

self.recursion_limit += 1;
let mark =
Mark { file_id: self.current_file_id, bomb: DropBomb::new("expansion mark dropped") };
self.cfg_expander.hygiene = Hygiene::new(db.upcast(), file_id);
self.current_file_id = file_id;

ExpandResult { value: Some((mark, node)), err }
ExpandResult { value: Some((file_id, raw_node)), err }
}

pub fn exit(&mut self, db: &dyn DefDatabase, mut mark: Mark) {
self.cfg_expander.hygiene = Hygiene::new(db.upcast(), mark.file_id);
self.current_file_id = mark.file_id;
self.recursion_limit -= 1;
if self.recursion_depth == usize::MAX {
// Recursion limit has been reached somewhere in the macro expansion tree. Reset the
// depth only when we get out of the tree.
if !self.current_file_id.is_macro() {
self.recursion_depth = 0;
}
} else {
self.recursion_depth -= 1;
}
mark.bomb.defuse();
}

Expand Down Expand Up @@ -215,6 +213,50 @@ impl Expander {
#[cfg(test)]
return Limit::new(std::cmp::min(32, limit));
}

fn within_limit<F, T: ast::AstNode>(
&mut self,
db: &dyn DefDatabase,
op: F,
) -> ExpandResult<Option<(Mark, T)>>
where
F: FnOnce(&mut Self) -> ExpandResult<Option<MacroCallId>>,
{
if self.recursion_depth == usize::MAX {
// Recursion limit has been reached somewhere in the macro expansion tree. We should
// stop expanding other macro calls in this tree, or else this may result in
// exponential number of macro expansions, leading to a hang.
//
// The overflow error should have been reported when it occurred (see the next branch),
// so don't return overflow error here to avoid diagnostics duplication.
cov_mark::hit!(overflow_but_not_me);
return ExpandResult::only_err(ExpandError::RecursionOverflowPosioned);
} else if self.recursion_limit(db).check(self.recursion_depth + 1).is_err() {
self.recursion_depth = usize::MAX;
cov_mark::hit!(your_stack_belongs_to_me);
return ExpandResult::only_err(ExpandError::Other(
"reached recursion limit during macro expansion".into(),
));
}

let ExpandResult { value, err } = op(self);
let Some(call_id) = value else {
return ExpandResult { value: None, err };
};

Self::enter_expand_inner(db, call_id, err).map(|value| {
value.and_then(|(new_file_id, node)| {
let node = T::cast(node)?;

self.recursion_depth += 1;
self.cfg_expander.hygiene = Hygiene::new(db.upcast(), new_file_id);
let old_file_id = std::mem::replace(&mut self.current_file_id, new_file_id);
let mark =
Mark { file_id: old_file_id, bomb: DropBomb::new("expansion mark dropped") };
Some((mark, node))
})
})
}
}

#[derive(Debug)]
Expand Down
6 changes: 6 additions & 0 deletions crates/hir-def/src/body/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,10 @@ impl ExprCollector<'_> {
krate: *krate,
});
}
Some(ExpandError::RecursionOverflowPosioned) => {
// Recursion limit has been reached in the macro expansion tree, but not in
// this very macro call. Don't add diagnostics to avoid duplication.
}
Some(err) => {
self.source_map.diagnostics.push(BodyDiagnostic::MacroError {
node: InFile::new(outer_file, syntax_ptr),
Expand All @@ -636,6 +640,8 @@ impl ExprCollector<'_> {

match res.value {
Some((mark, expansion)) => {
// Keep collecting even with expansion errors so we can provide completions and
// other services in incomplete macro expressions.
self.source_map.expansions.insert(macro_call_ptr, self.expander.current_file_id);
let prev_ast_id_map = mem::replace(
&mut self.ast_id_map,
Expand Down
13 changes: 13 additions & 0 deletions crates/hir-def/src/body/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ fn main() { n_nuple!(1,2,3); }
);
}

#[test]
fn your_stack_belongs_to_me2() {
cov_mark::check!(overflow_but_not_me);
lower(
r#"
macro_rules! foo {
() => {{ foo!(); foo!(); }}
}
fn main() { foo!(); }
"#,
);
}

#[test]
fn recursion_limit() {
cov_mark::check!(your_stack_belongs_to_me);
Expand Down
4 changes: 4 additions & 0 deletions crates/hir-expand/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ pub type ExpandResult<T> = ValueResult<T, ExpandError>;
pub enum ExpandError {
UnresolvedProcMacro(CrateId),
Mbe(mbe::ExpandError),
RecursionOverflowPosioned,
Other(Box<str>),
}

Expand All @@ -69,6 +70,9 @@ impl fmt::Display for ExpandError {
match self {
ExpandError::UnresolvedProcMacro(_) => f.write_str("unresolved proc-macro"),
ExpandError::Mbe(it) => it.fmt(f),
ExpandError::RecursionOverflowPosioned => {
f.write_str("overflow expanding the original macro")
}
ExpandError::Other(it) => f.write_str(it),
}
}
Expand Down

0 comments on commit 2a57b01

Please sign in to comment.