diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index 1a60ef729b637..7e2c5a19484b5 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -666,7 +666,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): fn comprehension_scope() { let TestCase { db, file } = test_case( " -[x for x in iter1] +[x for x, y in iter1] ", ); @@ -690,7 +690,22 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): let comprehension_symbol_table = index.symbol_table(comprehension_scope_id); - assert_eq!(names(&comprehension_symbol_table), vec!["x"]); + assert_eq!(names(&comprehension_symbol_table), vec!["x", "y"]); + + let use_def = index.use_def_map(comprehension_scope_id); + for name in ["x", "y"] { + let definition = use_def + .first_public_definition( + comprehension_symbol_table + .symbol_id_by_name(name) + .expect("symbol exists"), + ) + .unwrap(); + assert!(matches!( + definition.node(&db), + DefinitionKind::Comprehension(_) + )); + } } /// Test case to validate that the `x` variable used in the comprehension is referencing the @@ -730,8 +745,8 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): let DefinitionKind::Comprehension(comprehension) = definition.node(&db) else { panic!("expected generator definition") }; - let ast::Comprehension { target, .. } = comprehension.node(); - let name = target.as_name_expr().unwrap().id().as_str(); + let target = comprehension.target(); + let name = target.id().as_str(); assert_eq!(name, "x"); assert_eq!(target.range(), TextRange::new(23.into(), 24.into())); diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 3f6d0c23e041b..38637abeb21f7 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -285,6 +285,7 @@ impl<'db> SemanticIndexBuilder<'db> { // The `iter` of the first generator is evaluated in the outer scope, while all subsequent // nodes are evaluated in the inner scope. + self.add_standalone_expression(&generator.iter); self.visit_expr(&generator.iter); self.push_scope(scope); @@ -300,6 +301,7 @@ impl<'db> SemanticIndexBuilder<'db> { } for generator in generators_iter { + self.add_standalone_expression(&generator.iter); self.visit_expr(&generator.iter); self.current_assignment = Some(CurrentAssignment::Comprehension { @@ -678,7 +680,11 @@ where Some(CurrentAssignment::Comprehension { node, first }) => { self.add_definition( symbol, - ComprehensionDefinitionNodeRef { node, first }, + ComprehensionDefinitionNodeRef { + iterable: &node.iter, + target: name_node, + first, + }, ); } Some(CurrentAssignment::WithItem(with_item)) => { diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 537a17c8c18a0..07c36f7361afa 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -156,7 +156,8 @@ pub(crate) struct ForStmtDefinitionNodeRef<'a> { #[derive(Copy, Clone, Debug)] pub(crate) struct ComprehensionDefinitionNodeRef<'a> { - pub(crate) node: &'a ast::Comprehension, + pub(crate) iterable: &'a ast::Expr, + pub(crate) target: &'a ast::ExprName, pub(crate) first: bool, } @@ -211,12 +212,15 @@ impl DefinitionNodeRef<'_> { target: AstNodeRef::new(parsed, target), }) } - DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { node, first }) => { - DefinitionKind::Comprehension(ComprehensionDefinitionKind { - node: AstNodeRef::new(parsed, node), - first, - }) - } + DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { + iterable, + target, + first, + }) => DefinitionKind::Comprehension(ComprehensionDefinitionKind { + iterable: AstNodeRef::new(parsed.clone(), iterable), + target: AstNodeRef::new(parsed, target), + first, + }), DefinitionNodeRef::Parameter(parameter) => match parameter { ast::AnyParameterRef::Variadic(parameter) => { DefinitionKind::Parameter(AstNodeRef::new(parsed, parameter)) @@ -262,7 +266,7 @@ impl DefinitionNodeRef<'_> { iterable: _, target, }) => target.into(), - Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(), + Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(), Self::Parameter(node) => match node { ast::AnyParameterRef::Variadic(parameter) => parameter.into(), ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(), @@ -313,13 +317,18 @@ impl MatchPatternDefinitionKind { #[derive(Clone, Debug)] pub struct ComprehensionDefinitionKind { - node: AstNodeRef, + iterable: AstNodeRef, + target: AstNodeRef, first: bool, } impl ComprehensionDefinitionKind { - pub(crate) fn node(&self) -> &ast::Comprehension { - self.node.node() + pub(crate) fn iterable(&self) -> &ast::Expr { + self.iterable.node() + } + + pub(crate) fn target(&self) -> &ast::ExprName { + self.target.node() } pub(crate) fn is_first(&self) -> bool { @@ -442,12 +451,6 @@ impl From<&ast::StmtFor> for DefinitionNodeKey { } } -impl From<&ast::Comprehension> for DefinitionNodeKey { - fn from(node: &ast::Comprehension) -> Self { - Self(NodeKey::from_node(node)) - } -} - impl From<&ast::Parameter> for DefinitionNodeKey { fn from(node: &ast::Parameter) -> Self { Self(NodeKey::from_node(node)) diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 09a927a4848cd..b9b8f900b731e 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -403,7 +403,8 @@ impl<'db> TypeInferenceBuilder<'db> { } DefinitionKind::Comprehension(comprehension) => { self.infer_comprehension_definition( - comprehension.node(), + comprehension.iterable(), + comprehension.target(), comprehension.is_first(), definition, ); @@ -1545,11 +1546,11 @@ impl<'db> TypeInferenceBuilder<'db> { /// Infer the type of the `iter` expression of the first comprehension. fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) { - let mut generators_iter = comprehensions.iter(); - let Some(first_generator) = generators_iter.next() else { + let mut comprehensions_iter = comprehensions.iter(); + let Some(first_comprehension) = comprehensions_iter.next() else { unreachable!("Comprehension must contain at least one generator"); }; - self.infer_expression(&first_generator.iter); + self.infer_expression(&first_comprehension.iter); } fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> { @@ -1615,9 +1616,7 @@ impl<'db> TypeInferenceBuilder<'db> { } = generator; self.infer_expression(elt); - for comprehension in generators { - self.infer_comprehension(comprehension); - } + self.infer_comprehensions(generators); } fn infer_list_comprehension_expression_scope(&mut self, listcomp: &ast::ExprListComp) { @@ -1628,9 +1627,7 @@ impl<'db> TypeInferenceBuilder<'db> { } = listcomp; self.infer_expression(elt); - for comprehension in generators { - self.infer_comprehension(comprehension); - } + self.infer_comprehensions(generators); } fn infer_dict_comprehension_expression_scope(&mut self, dictcomp: &ast::ExprDictComp) { @@ -1643,9 +1640,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_expression(key); self.infer_expression(value); - for comprehension in generators { - self.infer_comprehension(comprehension); - } + self.infer_comprehensions(generators); } fn infer_set_comprehension_expression_scope(&mut self, setcomp: &ast::ExprSetComp) { @@ -1656,37 +1651,68 @@ impl<'db> TypeInferenceBuilder<'db> { } = setcomp; self.infer_expression(elt); - for comprehension in generators { - self.infer_comprehension(comprehension); - } + self.infer_comprehensions(generators); } - fn infer_comprehension(&mut self, comprehension: &ast::Comprehension) { - self.infer_definition(comprehension); - for expr in &comprehension.ifs { - self.infer_expression(expr); + fn infer_comprehensions(&mut self, comprehensions: &[ast::Comprehension]) { + let mut comprehensions_iter = comprehensions.iter(); + let Some(first_comprehension) = comprehensions_iter.next() else { + unreachable!("Comprehension must contain at least one generator"); + }; + self.infer_comprehension(first_comprehension, true); + for comprehension in comprehensions_iter { + self.infer_comprehension(comprehension, false); } } - fn infer_comprehension_definition( - &mut self, - comprehension: &ast::Comprehension, - is_first: bool, - definition: Definition<'db>, - ) { + fn infer_comprehension(&mut self, comprehension: &ast::Comprehension, is_first: bool) { let ast::Comprehension { range: _, target, iter, - ifs: _, + ifs, is_async: _, } = comprehension; if !is_first { self.infer_expression(iter); } - // TODO(dhruvmanila): The target type should be inferred based on the iter type instead. - let target_ty = self.infer_expression(target); + // TODO more complex assignment targets + if let ast::Expr::Name(name) = target { + self.infer_definition(name); + } else { + self.infer_expression(target); + } + for expr in ifs { + self.infer_expression(expr); + } + } + + fn infer_comprehension_definition( + &mut self, + iterable: &ast::Expr, + target: &ast::ExprName, + is_first: bool, + definition: Definition<'db>, + ) { + if !is_first { + let expression = self.index.expression(iterable); + let result = infer_expression_types(self.db, expression); + self.extend(result); + let _iterable_ty = self + .types + .expression_ty(iterable.scoped_ast_id(self.db, self.scope)); + } + // TODO(dhruvmanila): The iter type for the first comprehension is coming from the + // enclosing scope. + + // TODO(dhruvmanila): The target type should be inferred based on the iter type instead, + // similar to how it's done in `infer_for_statement_definition`. + let target_ty = Type::Unknown; + + self.types + .expressions + .insert(target.scoped_ast_id(self.db, self.scope), target_ty); self.types.definitions.insert(definition, target_ty); }