Skip to content

Commit

Permalink
Fixes issue of missing type annotations on explicit returns. (#5575)
Browse files Browse the repository at this point in the history
## Description

Explicit returns did not have any type annotation causing the method
disambiguation to fail.

The solution was to add a new field to the `TypeCheckContext` called
`function_type_annotation`. This field is initialized only in the
type_check of function declarations. The field is used to set the
type_annotation one while doing the type check of explicit return
expressions.

With this change `TypeCheckUnification` was also simplified as the
unification is now done with `function_type_annotation`. The result was
a few functions that are no longer required to be removed.

Closes #5518

## Checklist

- [x] I have linked to any relevant issues.
- [x] I have commented my code, particularly in hard-to-understand
areas.
- [ ] I have updated the documentation where relevant (API docs, the
reference, and the Sway book).
- [x] I have added tests that prove my fix is effective or that my
feature works.
- [x] I have added (or requested a maintainer to add) the necessary
`Breaking*` or `New Feature` labels where relevant.
- [x] I have done my best to ensure that my PR adheres to [the Fuel Labs
Code Review
Standards](https://github.com/FuelLabs/rfcs/blob/master/text/code-standards/external-contributors.md).
- [x] I have requested a review from the relevant team or maintainers.

---------

Co-authored-by: João Matos <joao@tritao.eu>
  • Loading branch information
esdrubal and tritao authored Feb 8, 2024
1 parent 046d292 commit 52767cc
Show file tree
Hide file tree
Showing 13 changed files with 115 additions and 276 deletions.
16 changes: 0 additions & 16 deletions sway-core/src/language/ty/ast_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,6 @@ impl GetDeclIdent for TyAstNode {
}

impl TyAstNode {
/// recurse into `self` and get any return statements -- used to validate that all returns
/// do indeed return the correct type
/// This does _not_ extract implicit return statements as those are not control flow! This is
/// _only_ for explicit returns.
pub(crate) fn gather_return_statements(&self) -> Vec<&TyExpression> {
match &self.content {
// assignments and reassignments can happen during control flow and can abort
TyAstNodeContent::Declaration(TyDecl::VariableDecl(decl)) => {
decl.body.gather_return_statements()
}
TyAstNodeContent::Expression(exp) => exp.gather_return_statements(),
TyAstNodeContent::Error(_, _) => vec![],
TyAstNodeContent::SideEffect(_) | TyAstNodeContent::Declaration(_) => vec![],
}
}

/// Returns `true` if this AST node will be exported in a library, i.e. it is a public declaration.
pub(crate) fn is_public(&self, decl_engine: &DeclEngine) -> bool {
match &self.content {
Expand Down
8 changes: 0 additions & 8 deletions sway-core/src/language/ty/expression/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,14 +424,6 @@ impl TyExpression {
}
}

/// recurse into `self` and get any return statements -- used to validate that all returns
/// do indeed return the correct type
/// This does _not_ extract implicit return statements as those are not control flow! This is
/// _only_ for explicit returns.
pub(crate) fn gather_return_statements(&self) -> Vec<&TyExpression> {
self.expression.gather_return_statements()
}

/// gathers the mutability of the expressions within
pub(crate) fn gather_mutability(&self) -> VariableMutability {
match &self.expression {
Expand Down
117 changes: 0 additions & 117 deletions sway-core/src/language/ty/expression/expression_variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1456,121 +1456,4 @@ impl TyExpressionVariant {
_ => None,
}
}

/// Recurse into `self` and get any return statements -- used to validate that all returns
/// do indeed return the correct type.
/// This does _not_ extract implicit return statements as those are not control flow! This is
/// _only_ for explicit returns.
pub(crate) fn gather_return_statements(&self) -> Vec<&TyExpression> {
match self {
TyExpressionVariant::MatchExp { desugared, .. } => {
desugared.expression.gather_return_statements()
}
TyExpressionVariant::IfExp {
condition,
then,
r#else,
} => {
let mut buf = condition.gather_return_statements();
buf.append(&mut then.gather_return_statements());
if let Some(ref r#else) = r#else {
buf.append(&mut r#else.gather_return_statements());
}
buf
}
TyExpressionVariant::CodeBlock(TyCodeBlock { contents, .. }) => {
let mut buf = vec![];
for node in contents {
buf.append(&mut node.gather_return_statements())
}
buf
}
TyExpressionVariant::WhileLoop { condition, body } => {
let mut buf = condition.gather_return_statements();
for node in &body.contents {
buf.append(&mut node.gather_return_statements())
}
buf
}
TyExpressionVariant::Reassignment(reassignment) => {
reassignment.rhs.gather_return_statements()
}
TyExpressionVariant::LazyOperator { lhs, rhs, .. } => [lhs, rhs]
.into_iter()
.flat_map(|expr| expr.gather_return_statements())
.collect(),
TyExpressionVariant::Tuple { fields } => fields
.iter()
.flat_map(|expr| expr.gather_return_statements())
.collect(),
TyExpressionVariant::Array {
elem_type: _,
contents,
} => contents
.iter()
.flat_map(|expr| expr.gather_return_statements())
.collect(),
TyExpressionVariant::ArrayIndex { prefix, index } => [prefix, index]
.into_iter()
.flat_map(|expr| expr.gather_return_statements())
.collect(),
TyExpressionVariant::StructFieldAccess { prefix, .. } => {
prefix.gather_return_statements()
}
TyExpressionVariant::TupleElemAccess { prefix, .. } => {
prefix.gather_return_statements()
}
TyExpressionVariant::EnumInstantiation { contents, .. } => contents
.iter()
.flat_map(|expr| expr.gather_return_statements())
.collect(),
TyExpressionVariant::AbiCast { address, .. } => address.gather_return_statements(),
TyExpressionVariant::IntrinsicFunction(intrinsic_function_kind) => {
intrinsic_function_kind
.arguments
.iter()
.flat_map(|expr| expr.gather_return_statements())
.collect()
}
TyExpressionVariant::StructExpression { fields, .. } => fields
.iter()
.flat_map(|field| field.value.gather_return_statements())
.collect(),
TyExpressionVariant::FunctionApplication {
contract_call_params,
arguments,
selector,
..
} => contract_call_params
.values()
.chain(arguments.iter().map(|(_name, expr)| expr))
.chain(
selector
.iter()
.map(|contract_call_params| &*contract_call_params.contract_address),
)
.flat_map(|expr| expr.gather_return_statements())
.collect(),
TyExpressionVariant::EnumTag { exp } => exp.gather_return_statements(),
TyExpressionVariant::UnsafeDowncast { exp, .. } => exp.gather_return_statements(),
TyExpressionVariant::ImplicitReturn(exp) => exp.gather_return_statements(),
TyExpressionVariant::Return(exp) => {
vec![exp]
}
TyExpressionVariant::Ref(exp) | TyExpressionVariant::Deref(exp) => {
exp.gather_return_statements()
}
// if it is impossible for an expression to contain a return _statement_ (not an
// implicit return!), put it in the pattern below.
TyExpressionVariant::Literal(_)
| TyExpressionVariant::FunctionParameter { .. }
| TyExpressionVariant::AsmExpression { .. }
| TyExpressionVariant::ConstantExpression { .. }
| TyExpressionVariant::VariableExpression { .. }
| TyExpressionVariant::AbiName(_)
| TyExpressionVariant::StorageAccess { .. }
| TyExpressionVariant::Break
| TyExpressionVariant::Continue => vec![],
}
}
}
44 changes: 2 additions & 42 deletions sway-core/src/semantic_analysis/ast_node/declaration/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ impl ty::TyFunctionDecl {
.with_help_text(
"Function body's return type does not match up with its return type annotation.",
)
.with_type_annotation(return_type.type_id);
.with_type_annotation(return_type.type_id)
.with_function_type_annotation(return_type.type_id);

let body = ty::TyCodeBlock::type_check(handler, ctx.by_ref(), body)
.unwrap_or_else(|_err| ty::TyCodeBlock::default());
Expand All @@ -223,32 +224,6 @@ impl ty::TyFunctionDecl {
}
}

/// Unifies the types of the return statements and the return type of the
/// function declaration.
fn unify_return_statements(
handler: &Handler,
ctx: TypeCheckContext,
return_statements: &[&ty::TyExpression],
return_type: TypeId,
) -> Result<(), ErrorEmitted> {
let type_engine = ctx.engines.te();

handler.scope(|handler| {
for stmt in return_statements.iter() {
type_engine.unify(
handler,
ctx.engines(),
stmt.return_type,
return_type,
&stmt.span,
"Return statement must return the declared function return type.",
None,
);
}
Ok(())
})
}

impl TypeCheckAnalysis for DeclId<TyFunctionDecl> {
fn type_check_analyze(
&self,
Expand Down Expand Up @@ -314,21 +289,6 @@ impl TypeCheckUnification for ty::TyFunctionDecl {

let return_type = &self.return_type;

// gather the return statements
let return_statements: Vec<&ty::TyExpression> = self
.body
.contents
.iter()
.flat_map(|node| node.gather_return_statements())
.collect();

unify_return_statements(
handler,
type_check_ctx.by_ref(),
&return_statements,
return_type.type_id,
)?;

return_type.type_id.check_type_parameter_bounds(
handler,
type_check_ctx.by_ref(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,20 +393,12 @@ impl ty::TyExpression {
Ok(typed_expr)
}
ExpressionKind::Return(expr) => {
let function_type_annotation = ctx.function_type_annotation();
let ctx = ctx
// we use "unknown" here because return statements do not
// necessarily follow the type annotation of their immediate
// surrounding context. Because a return statement is control flow
// that breaks out to the nearest function, we need to type check
// it against the surrounding function.
// That is impossible here, as we don't have that information. It
// is the responsibility of the function declaration to type check
// all return statements contained within it.
.by_ref()
.with_type_annotation(type_engine.insert(engines, TypeInfo::Unknown, None))
.with_type_annotation(function_type_annotation)
.with_help_text(
"Returned value must match up with the function return type \
annotation.",
"Return statement must return the declared function return type.",
);
let expr_span = expr.span();
let expr = ty::TyExpression::type_check(handler, ctx, *expr)
Expand Down
17 changes: 17 additions & 0 deletions sway-core/src/semantic_analysis/type_check_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub struct TypeCheckContext<'a> {
///
/// Assists type inference.
type_annotation: TypeId,
/// Assists type inference.
function_type_annotation: TypeId,
/// When true unify_with_type_annotation will use unify_with_generic instead of the default unify.
/// This ensures that expected generic types are unified to more specific received types.
unify_generic: bool,
Expand Down Expand Up @@ -118,6 +120,7 @@ impl<'a> TypeCheckContext<'a> {
namespace,
engines,
type_annotation: engines.te().insert(engines, TypeInfo::Unknown, None),
function_type_annotation: engines.te().insert(engines, TypeInfo::Unknown, None),
unify_generic: false,
self_type: None,
type_subst: TypeSubstMap::new(),
Expand Down Expand Up @@ -146,6 +149,7 @@ impl<'a> TypeCheckContext<'a> {
TypeCheckContext {
namespace: self.namespace,
type_annotation: self.type_annotation,
function_type_annotation: self.function_type_annotation,
unify_generic: self.unify_generic,
self_type: self.self_type,
type_subst: self.type_subst.clone(),
Expand All @@ -168,6 +172,7 @@ impl<'a> TypeCheckContext<'a> {
TypeCheckContext {
namespace,
type_annotation: self.type_annotation,
function_type_annotation: self.function_type_annotation,
unify_generic: self.unify_generic,
self_type: self.self_type,
type_subst: self.type_subst,
Expand Down Expand Up @@ -217,6 +222,14 @@ impl<'a> TypeCheckContext<'a> {
}
}

/// Map this `TypeCheckContext` instance to a new one with the given type annotation.
pub(crate) fn with_function_type_annotation(self, function_type_annotation: TypeId) -> Self {
Self {
function_type_annotation,
..self
}
}

/// Map this `TypeCheckContext` instance to a new one with the given type annotation.
pub(crate) fn with_unify_generic(self, unify_generic: bool) -> Self {
Self {
Expand Down Expand Up @@ -322,6 +335,10 @@ impl<'a> TypeCheckContext<'a> {
self.type_annotation
}

pub(crate) fn function_type_annotation(&self) -> TypeId {
self.function_type_annotation
}

pub(crate) fn unify_generic(&self) -> bool {
self.unify_generic
}
Expand Down
Loading

0 comments on commit 52767cc

Please sign in to comment.