Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] Emit a diagnostic if the value of a starred expression or a yield from expression is not iterable #13240

Merged
merged 5 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 108 additions & 10 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use infer::TypeInferenceBuilder;
use ruff_db::files::File;
use ruff_python_ast as ast;

Expand Down Expand Up @@ -400,28 +401,42 @@ impl<'db> Type<'db> {
/// for y in x:
/// pass
/// ```
///
/// Returns `None` if `self` represents a type that is not iterable.
fn iterate(&self, db: &'db dyn Db) -> Option<Type<'db>> {
fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> {
// `self` represents the type of the iterable;
// `__iter__` and `__next__` are both looked up on the class of the iterable:
let type_of_class = self.to_meta_type(db);
let iterable_meta_type = self.to_meta_type(db);

let dunder_iter_method = type_of_class.member(db, "__iter__");
let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
if !dunder_iter_method.is_unbound() {
let iterator_ty = dunder_iter_method.call(db)?;
let Some(iterator_ty) = dunder_iter_method.call(db) else {
return IterationOutcome::NotIterable {
not_iterable_ty: *self,
};
};

let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
return dunder_next_method.call(db);
return dunder_next_method
.call(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
});
}

// Although it's not considered great practice,
// classes that define `__getitem__` are also iterable,
// even if they do not define `__iter__`.
//
// TODO this is only valid if the `__getitem__` method is annotated as
// TODO(Alex) this is only valid if the `__getitem__` method is annotated as
// accepting `int` or `SupportsIndex`
let dunder_get_item_method = type_of_class.member(db, "__getitem__");
dunder_get_item_method.call(db)
let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__");

dunder_get_item_method
.call(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
})
}

#[must_use]
Expand Down Expand Up @@ -463,6 +478,28 @@ impl<'db> Type<'db> {
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IterationOutcome<'db> {
Iterable { element_ty: Type<'db> },
NotIterable { not_iterable_ty: Type<'db> },
}

impl<'db> IterationOutcome<'db> {
fn unwrap_with_diagnostic(
self,
iterable_node: ast::AnyNodeRef,
inference_builder: &mut TypeInferenceBuilder<'db>,
) -> Type<'db> {
match self {
Self::Iterable { element_ty } => element_ty,
Self::NotIterable { not_iterable_ty } => {
inference_builder.not_iterable_diagnostic(iterable_node, not_iterable_ty);
Type::Unknown
}
}
}
}

#[salsa::interned]
pub struct FunctionType<'db> {
/// name of the function at definition
Expand Down Expand Up @@ -789,4 +826,65 @@ mod tests {
&["Object of type 'NotIterable' is not iterable"],
);
}

#[test]
fn starred_expressions_must_be_iterable() {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
class NotIterable: pass

class Iterator:
def __next__(self) -> int:
return 42

class Iterable:
def __iter__(self) -> Iterator:

x = [*NotIterable()]
y = [*Iterable()]
",
)
.unwrap();

let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'NotIterable' is not iterable"],
);
}

#[test]
fn yield_from_expression_must_be_iterable() {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
class NotIterable: pass

class Iterator:
def __next__(self) -> int:
return 42

class Iterable:
def __iter__(self) -> Iterator:

def generator_function():
yield from Iterable()
yield from NotIterable()
",
)
.unwrap();

let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'NotIterable' is not iterable"],
);
}
}
40 changes: 25 additions & 15 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ impl<'db> TypeInference<'db> {
/// Similarly, when we encounter a standalone-inferable expression (right-hand side of an
/// assignment, type narrowing guard), we use the [`infer_expression_types()`] query to ensure we
/// don't infer its types more than once.
struct TypeInferenceBuilder<'db> {
pub(super) struct TypeInferenceBuilder<'db> {
db: &'db dyn Db,
index: &'db SemanticIndex<'db>,
region: InferenceRegion<'db>,
Expand Down Expand Up @@ -1029,6 +1029,18 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_body(orelse);
}

/// Emit a diagnostic declaring that the object represented by `node` is not iterable
pub(super) fn not_iterable_diagnostic(&mut self, node: AnyNodeRef, not_iterable_ty: Type<'db>) {
self.add_diagnostic(
node,
"not-iterable",
format_args!(
"Object of type '{}' is not iterable",
not_iterable_ty.display(self.db)
),
);
}

fn infer_for_statement_definition(
&mut self,
target: &ast::ExprName,
Expand All @@ -1042,17 +1054,9 @@ impl<'db> TypeInferenceBuilder<'db> {
.types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));

let loop_var_value_ty = iterable_ty.iterate(self.db).unwrap_or_else(|| {
self.add_diagnostic(
iterable.into(),
"not-iterable",
format_args!(
"Object of type '{}' is not iterable",
iterable_ty.display(self.db)
),
);
Type::Unknown
});
let loop_var_value_ty = iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(iterable.into(), self);

self.types
.expressions
Expand Down Expand Up @@ -1812,7 +1816,10 @@ impl<'db> TypeInferenceBuilder<'db> {
ctx: _,
} = starred;

self.infer_expression(value);
let iterable_ty = self.infer_expression(value);
iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(value.as_ref().into(), self);

// TODO
Type::Unknown
Expand All @@ -1830,9 +1837,12 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_yield_from_expression(&mut self, yield_from: &ast::ExprYieldFrom) -> Type<'db> {
let ast::ExprYieldFrom { range: _, value } = yield_from;

self.infer_expression(value);
let iterable_ty = self.infer_expression(value);
iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(value.as_ref().into(), self);

// TODO get type from awaitable
// TODO get type from `ReturnType` of generator
Type::Unknown
}

Expand Down
Loading