From 5ccc21aea2b40aee927b356cef48a1ca0fbcd6c9 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Wed, 20 Dec 2023 00:06:31 -0500 Subject: [PATCH] Add support for `NoReturn` in auto-return-typing (#9206) ## Summary Given a function like: ```python def func(x: int): if not x: raise ValueError else: raise TypeError ``` We now correctly use `NoReturn` as the return type, rather than `None`. Closes https://github.com/astral-sh/ruff/issues/9201. --- .../flake8_annotations/auto_return_type.py | 30 ++ .../src/rules/flake8_annotations/helpers.rs | 28 +- ..._annotations__tests__auto_return_type.snap | 84 +++++ ...tations__tests__auto_return_type_py38.snap | 92 +++++ crates/ruff_python_ast/src/helpers.rs | 322 ++++++++++-------- 5 files changed, 406 insertions(+), 150 deletions(-) diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_annotations/auto_return_type.py b/crates/ruff_linter/resources/test/fixtures/flake8_annotations/auto_return_type.py index 60d89cbec5cdc..a1df0985b0acd 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_annotations/auto_return_type.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_annotations/auto_return_type.py @@ -182,3 +182,33 @@ def method(self): return 1 else: return 1.5 + + +def func(x: int): + try: + pass + except: + return 2 + + +def func(x: int): + try: + pass + except: + return 2 + else: + return 3 + + +def func(x: int): + if not x: + raise ValueError + else: + raise TypeError + + +def func(x: int): + if not x: + raise ValueError + else: + return 1 diff --git a/crates/ruff_linter/src/rules/flake8_annotations/helpers.rs b/crates/ruff_linter/src/rules/flake8_annotations/helpers.rs index 41cff286d79c5..baf316357da55 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/helpers.rs +++ b/crates/ruff_linter/src/rules/flake8_annotations/helpers.rs @@ -3,7 +3,7 @@ use rustc_hash::FxHashSet; use ruff_diagnostics::Edit; use ruff_python_ast::helpers::{ - implicit_return, pep_604_union, typing_optional, typing_union, ReturnStatementVisitor, + pep_604_union, typing_optional, typing_union, ReturnStatementVisitor, Terminal, }; use ruff_python_ast::visitor::Visitor; use ruff_python_ast::{self as ast, Expr, ExprContext}; @@ -57,6 +57,14 @@ pub(crate) fn auto_return_type(function: &ast::StmtFunctionDef) -> Option Option 0: // return 1 // ``` - if implicit_return(function) { + if terminal.is_none() { return_type = return_type.union(ResolvedPythonType::Atom(PythonType::None)); } @@ -94,6 +102,7 @@ pub(crate) fn auto_return_type(function: &ast::StmtFunctionDef) -> Option), } @@ -111,6 +120,21 @@ impl AutoPythonType { target_version: PythonVersion, ) -> Option<(Expr, Vec)> { match self { + AutoPythonType::NoReturn => { + let (no_return_edit, binding) = importer + .get_or_import_symbol( + &ImportRequest::import_from("typing", "NoReturn"), + at, + semantic, + ) + .ok()?; + let expr = Expr::Name(ast::ExprName { + id: binding, + range: TextRange::default(), + ctx: ExprContext::Load, + }); + Some((expr, vec![no_return_edit])) + } AutoPythonType::Atom(python_type) => { let expr = type_expr(python_type)?; Some((expr, vec![])) diff --git a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type.snap b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type.snap index 72b8e8b9a66cc..1dc85efd74b0d 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type.snap +++ b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type.snap @@ -495,4 +495,88 @@ auto_return_type.py:180:9: ANN201 [*] Missing return type annotation for public 182 182 | return 1 183 183 | else: +auto_return_type.py:187:5: ANN201 [*] Missing return type annotation for public function `func` + | +187 | def func(x: int): + | ^^^^ ANN201 +188 | try: +189 | pass + | + = help: Add return type annotation: `int | None` + +ℹ Unsafe fix +184 184 | return 1.5 +185 185 | +186 186 | +187 |-def func(x: int): + 187 |+def func(x: int) -> int | None: +188 188 | try: +189 189 | pass +190 190 | except: + +auto_return_type.py:194:5: ANN201 [*] Missing return type annotation for public function `func` + | +194 | def func(x: int): + | ^^^^ ANN201 +195 | try: +196 | pass + | + = help: Add return type annotation: `int` + +ℹ Unsafe fix +191 191 | return 2 +192 192 | +193 193 | +194 |-def func(x: int): + 194 |+def func(x: int) -> int: +195 195 | try: +196 196 | pass +197 197 | except: + +auto_return_type.py:203:5: ANN201 [*] Missing return type annotation for public function `func` + | +203 | def func(x: int): + | ^^^^ ANN201 +204 | if not x: +205 | raise ValueError + | + = help: Add return type annotation: `NoReturn` + +ℹ Unsafe fix +151 151 | +152 152 | import abc +153 153 | from abc import abstractmethod + 154 |+from typing import NoReturn +154 155 | +155 156 | +156 157 | class Foo(abc.ABC): +-------------------------------------------------------------------------------- +200 201 | return 3 +201 202 | +202 203 | +203 |-def func(x: int): + 204 |+def func(x: int) -> NoReturn: +204 205 | if not x: +205 206 | raise ValueError +206 207 | else: + +auto_return_type.py:210:5: ANN201 [*] Missing return type annotation for public function `func` + | +210 | def func(x: int): + | ^^^^ ANN201 +211 | if not x: +212 | raise ValueError + | + = help: Add return type annotation: `int` + +ℹ Unsafe fix +207 207 | raise TypeError +208 208 | +209 209 | +210 |-def func(x: int): + 210 |+def func(x: int) -> int: +211 211 | if not x: +212 212 | raise ValueError +213 213 | else: + diff --git a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type_py38.snap b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type_py38.snap index 7667c2299de9e..a2fb6448f7cdc 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type_py38.snap +++ b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type_py38.snap @@ -550,4 +550,96 @@ auto_return_type.py:180:9: ANN201 [*] Missing return type annotation for public 182 182 | return 1 183 183 | else: +auto_return_type.py:187:5: ANN201 [*] Missing return type annotation for public function `func` + | +187 | def func(x: int): + | ^^^^ ANN201 +188 | try: +189 | pass + | + = help: Add return type annotation: `Optional[int]` + +ℹ Unsafe fix +151 151 | +152 152 | import abc +153 153 | from abc import abstractmethod + 154 |+from typing import Optional +154 155 | +155 156 | +156 157 | class Foo(abc.ABC): +-------------------------------------------------------------------------------- +184 185 | return 1.5 +185 186 | +186 187 | +187 |-def func(x: int): + 188 |+def func(x: int) -> Optional[int]: +188 189 | try: +189 190 | pass +190 191 | except: + +auto_return_type.py:194:5: ANN201 [*] Missing return type annotation for public function `func` + | +194 | def func(x: int): + | ^^^^ ANN201 +195 | try: +196 | pass + | + = help: Add return type annotation: `int` + +ℹ Unsafe fix +191 191 | return 2 +192 192 | +193 193 | +194 |-def func(x: int): + 194 |+def func(x: int) -> int: +195 195 | try: +196 196 | pass +197 197 | except: + +auto_return_type.py:203:5: ANN201 [*] Missing return type annotation for public function `func` + | +203 | def func(x: int): + | ^^^^ ANN201 +204 | if not x: +205 | raise ValueError + | + = help: Add return type annotation: `NoReturn` + +ℹ Unsafe fix +151 151 | +152 152 | import abc +153 153 | from abc import abstractmethod + 154 |+from typing import NoReturn +154 155 | +155 156 | +156 157 | class Foo(abc.ABC): +-------------------------------------------------------------------------------- +200 201 | return 3 +201 202 | +202 203 | +203 |-def func(x: int): + 204 |+def func(x: int) -> NoReturn: +204 205 | if not x: +205 206 | raise ValueError +206 207 | else: + +auto_return_type.py:210:5: ANN201 [*] Missing return type annotation for public function `func` + | +210 | def func(x: int): + | ^^^^ ANN201 +211 | if not x: +212 | raise ValueError + | + = help: Add return type annotation: `int` + +ℹ Unsafe fix +207 207 | raise TypeError +208 208 | +209 209 | +210 |-def func(x: int): + 210 |+def func(x: int) -> int: +211 211 | if not x: +212 212 | raise ValueError +213 213 | else: + diff --git a/crates/ruff_python_ast/src/helpers.rs b/crates/ruff_python_ast/src/helpers.rs index fceba83d02a80..61e80657b7a5e 100644 --- a/crates/ruff_python_ast/src/helpers.rs +++ b/crates/ruff_python_ast/src/helpers.rs @@ -921,178 +921,204 @@ where } } -/// Returns `true` if the function has an implicit return. -pub fn implicit_return(function: &ast::StmtFunctionDef) -> bool { - /// Returns `true` if the body may break via a `break` statement. - fn sometimes_breaks(stmts: &[Stmt]) -> bool { - for stmt in stmts { - match stmt { - Stmt::For(ast::StmtFor { body, orelse, .. }) => { - if returns(body) { - return false; - } - if sometimes_breaks(orelse) { - return true; - } - } - Stmt::While(ast::StmtWhile { body, orelse, .. }) => { - if returns(body) { - return false; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Terminal { + /// Every path through the function ends with a `raise` statement. + Raise, + /// Every path through the function ends with a `return` (or `raise`) statement. + Return, +} + +impl Terminal { + /// Returns the [`Terminal`] behavior of the function, if it can be determined, or `None` if the + /// function contains at least one control flow path that does not end with a `return` or `raise` + /// statement. + pub fn from_function(function: &ast::StmtFunctionDef) -> Option { + /// Returns `true` if the body may break via a `break` statement. + fn sometimes_breaks(stmts: &[Stmt]) -> bool { + for stmt in stmts { + match stmt { + Stmt::For(ast::StmtFor { body, orelse, .. }) => { + if returns(body).is_some() { + return false; + } + if sometimes_breaks(orelse) { + return true; + } } - if sometimes_breaks(orelse) { - return true; + Stmt::While(ast::StmtWhile { body, orelse, .. }) => { + if returns(body).is_some() { + return false; + } + if sometimes_breaks(orelse) { + return true; + } } - } - Stmt::If(ast::StmtIf { - body, - elif_else_clauses, - .. - }) => { - if std::iter::once(body) - .chain(elif_else_clauses.iter().map(|clause| &clause.body)) - .any(|body| sometimes_breaks(body)) - { - return true; + Stmt::If(ast::StmtIf { + body, + elif_else_clauses, + .. + }) => { + if std::iter::once(body) + .chain(elif_else_clauses.iter().map(|clause| &clause.body)) + .any(|body| sometimes_breaks(body)) + { + return true; + } } - } - Stmt::Match(ast::StmtMatch { cases, .. }) => { - if cases.iter().any(|case| sometimes_breaks(&case.body)) { - return true; + Stmt::Match(ast::StmtMatch { cases, .. }) => { + if cases.iter().any(|case| sometimes_breaks(&case.body)) { + return true; + } } - } - Stmt::Try(ast::StmtTry { - body, - handlers, - orelse, - finalbody, - .. - }) => { - if sometimes_breaks(body) - || handlers.iter().any(|handler| { - let ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { - body, - .. - }) = handler; - sometimes_breaks(body) - }) - || sometimes_breaks(orelse) - || sometimes_breaks(finalbody) - { - return true; + Stmt::Try(ast::StmtTry { + body, + handlers, + orelse, + finalbody, + .. + }) => { + if sometimes_breaks(body) + || handlers.iter().any(|handler| { + let ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { + body, + .. + }) = handler; + sometimes_breaks(body) + }) + || sometimes_breaks(orelse) + || sometimes_breaks(finalbody) + { + return true; + } } - } - Stmt::With(ast::StmtWith { body, .. }) => { - if sometimes_breaks(body) { - return true; + Stmt::With(ast::StmtWith { body, .. }) => { + if sometimes_breaks(body) { + return true; + } } + Stmt::Break(_) => return true, + Stmt::Return(_) => return false, + Stmt::Raise(_) => return false, + _ => {} } - Stmt::Break(_) => return true, - Stmt::Return(_) => return false, - Stmt::Raise(_) => return false, - _ => {} } + false } - false - } - /// Returns `true` if the body may break via a `break` statement. - fn always_breaks(stmts: &[Stmt]) -> bool { - for stmt in stmts { - match stmt { - Stmt::Break(_) => return true, - Stmt::Return(_) => return false, - Stmt::Raise(_) => return false, - _ => {} + /// Returns `true` if the body may break via a `break` statement. + fn always_breaks(stmts: &[Stmt]) -> bool { + for stmt in stmts { + match stmt { + Stmt::Break(_) => return true, + Stmt::Return(_) => return false, + Stmt::Raise(_) => return false, + _ => {} + } } + false } - false - } - /// Returns `true` if the body contains a branch that ends without an explicit `return` or - /// `raise` statement. - fn returns(stmts: &[Stmt]) -> bool { - for stmt in stmts.iter().rev() { - match stmt { - Stmt::For(ast::StmtFor { body, orelse, .. }) => { - if always_breaks(body) { - return false; - } - if returns(body) { - return true; - } - if returns(orelse) && !sometimes_breaks(body) { - return true; - } - } - Stmt::While(ast::StmtWhile { body, orelse, .. }) => { - if always_breaks(body) { - return false; - } - if returns(body) { - return true; - } - if returns(orelse) && !sometimes_breaks(body) { - return true; - } - } - Stmt::If(ast::StmtIf { - body, - elif_else_clauses, - .. - }) => { - if elif_else_clauses.iter().any(|clause| clause.test.is_none()) - && std::iter::once(body) - .chain(elif_else_clauses.iter().map(|clause| &clause.body)) - .all(|body| returns(body)) - { - return true; + /// Returns `true` if the body contains a branch that ends without an explicit `return` or + /// `raise` statement. + fn returns(stmts: &[Stmt]) -> Option { + for stmt in stmts.iter().rev() { + match stmt { + Stmt::For(ast::StmtFor { body, orelse, .. }) + | Stmt::While(ast::StmtWhile { body, orelse, .. }) => { + if always_breaks(body) { + return None; + } + if let Some(terminal) = returns(body) { + return Some(terminal); + } + if !sometimes_breaks(body) { + if let Some(terminal) = returns(orelse) { + return Some(terminal); + } + } } - } - Stmt::Match(ast::StmtMatch { cases, .. }) => { - // Note: we assume the `match` is exhaustive. - if cases.iter().all(|case| returns(&case.body)) { - return true; + Stmt::If(ast::StmtIf { + body, + elif_else_clauses, + .. + }) => { + if elif_else_clauses.iter().any(|clause| clause.test.is_none()) { + match Terminal::combine(std::iter::once(returns(body)).chain( + elif_else_clauses.iter().map(|clause| returns(&clause.body)), + )) { + Some(Terminal::Raise) => return Some(Terminal::Raise), + Some(Terminal::Return) => return Some(Terminal::Return), + _ => {} + } + } } - } - Stmt::Try(ast::StmtTry { - body, - handlers, - orelse, - finalbody, - .. - }) => { - // If the `finally` block returns, the `try` block must also return. - if returns(finalbody) { - return true; + Stmt::Match(ast::StmtMatch { cases, .. }) => { + // Note: we assume the `match` is exhaustive. + match Terminal::combine(cases.iter().map(|case| returns(&case.body))) { + Some(Terminal::Raise) => return Some(Terminal::Raise), + Some(Terminal::Return) => return Some(Terminal::Return), + _ => {} + } } + Stmt::Try(ast::StmtTry { + body, + handlers, + orelse, + finalbody, + .. + }) => { + // If the `finally` block returns, the `try` block must also return. + if let Some(terminal) = returns(finalbody) { + return Some(terminal); + } - // If the `body` or the `else` block returns, the `try` block must also return. - if (returns(body) || returns(orelse)) - && handlers.iter().all(|handler| { - let ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { - body, - .. - }) = handler; - returns(body) - }) - { - return true; + // If the body returns, the `try` block must also return. + if returns(body) == Some(Terminal::Return) { + return Some(Terminal::Return); + } + + // If the else block and all the handlers return, the `try` block must also + // return. + if let Some(terminal) = + Terminal::combine(std::iter::once(returns(orelse)).chain( + handlers.iter().map(|handler| { + let ExceptHandler::ExceptHandler( + ast::ExceptHandlerExceptHandler { body, .. }, + ) = handler; + returns(body) + }), + )) + { + return Some(terminal); + } } - } - Stmt::With(ast::StmtWith { body, .. }) => { - if returns(body) { - return true; + Stmt::With(ast::StmtWith { body, .. }) => { + if let Some(terminal) = returns(body) { + return Some(terminal); + } } + Stmt::Return(_) => return Some(Terminal::Return), + Stmt::Raise(_) => return Some(Terminal::Raise), + _ => {} } - Stmt::Return(_) => return true, - Stmt::Raise(_) => return true, - _ => {} } + None } - false + + returns(&function.body) } - !returns(&function.body) + /// Combine a series of [`Terminal`] operators. + fn combine(iter: impl Iterator>) -> Option { + iter.reduce(|acc, terminal| match (acc, terminal) { + (Some(Self::Raise), Some(Self::Raise)) => Some(Self::Raise), + (Some(_), Some(Self::Return)) => Some(Self::Return), + (Some(Self::Return), Some(_)) => Some(Self::Return), + _ => None, + }) + .flatten() + } } /// A [`StatementVisitor`] that collects all `raise` statements in a function or method.