Skip to content

Commit

Permalink
[red-knot] Handle multiple comprehension targets (#13213)
Browse files Browse the repository at this point in the history
## Summary

Part of #13085, this PR updates the comprehension definition to handle
multiple targets.

## Test Plan

Update existing semantic index test case for comprehension with multiple
targets. Running corpus tests shouldn't panic.
  • Loading branch information
dhruvmanila authored Sep 4, 2024
1 parent 3c4ec82 commit e1e9143
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 51 deletions.
23 changes: 19 additions & 4 deletions crates/red_knot_python_semantic/src/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
fn comprehension_scope() {
let TestCase { db, file } = test_case(
"
[x for x in iter1]
[x for x, y in iter1]
",
);

Expand All @@ -690,7 +690,22 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):

let comprehension_symbol_table = index.symbol_table(comprehension_scope_id);

assert_eq!(names(&comprehension_symbol_table), vec!["x"]);
assert_eq!(names(&comprehension_symbol_table), vec!["x", "y"]);

let use_def = index.use_def_map(comprehension_scope_id);
for name in ["x", "y"] {
let definition = use_def
.first_public_definition(
comprehension_symbol_table
.symbol_id_by_name(name)
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
definition.node(&db),
DefinitionKind::Comprehension(_)
));
}
}

/// Test case to validate that the `x` variable used in the comprehension is referencing the
Expand Down Expand Up @@ -730,8 +745,8 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
let DefinitionKind::Comprehension(comprehension) = definition.node(&db) else {
panic!("expected generator definition")
};
let ast::Comprehension { target, .. } = comprehension.node();
let name = target.as_name_expr().unwrap().id().as_str();
let target = comprehension.target();
let name = target.id().as_str();

assert_eq!(name, "x");
assert_eq!(target.range(), TextRange::new(23.into(), 24.into()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ impl<'db> SemanticIndexBuilder<'db> {

// The `iter` of the first generator is evaluated in the outer scope, while all subsequent
// nodes are evaluated in the inner scope.
self.add_standalone_expression(&generator.iter);
self.visit_expr(&generator.iter);
self.push_scope(scope);

Expand All @@ -300,6 +301,7 @@ impl<'db> SemanticIndexBuilder<'db> {
}

for generator in generators_iter {
self.add_standalone_expression(&generator.iter);
self.visit_expr(&generator.iter);

self.current_assignment = Some(CurrentAssignment::Comprehension {
Expand Down Expand Up @@ -678,7 +680,11 @@ where
Some(CurrentAssignment::Comprehension { node, first }) => {
self.add_definition(
symbol,
ComprehensionDefinitionNodeRef { node, first },
ComprehensionDefinitionNodeRef {
iterable: &node.iter,
target: name_node,
first,
},
);
}
Some(CurrentAssignment::WithItem(with_item)) => {
Expand Down
37 changes: 20 additions & 17 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ pub(crate) struct ForStmtDefinitionNodeRef<'a> {

#[derive(Copy, Clone, Debug)]
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::Comprehension,
pub(crate) iterable: &'a ast::Expr,
pub(crate) target: &'a ast::ExprName,
pub(crate) first: bool,
}

Expand Down Expand Up @@ -211,12 +212,15 @@ impl DefinitionNodeRef<'_> {
target: AstNodeRef::new(parsed, target),
})
}
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { node, first }) => {
DefinitionKind::Comprehension(ComprehensionDefinitionKind {
node: AstNodeRef::new(parsed, node),
first,
})
}
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef {
iterable,
target,
first,
}) => DefinitionKind::Comprehension(ComprehensionDefinitionKind {
iterable: AstNodeRef::new(parsed.clone(), iterable),
target: AstNodeRef::new(parsed, target),
first,
}),
DefinitionNodeRef::Parameter(parameter) => match parameter {
ast::AnyParameterRef::Variadic(parameter) => {
DefinitionKind::Parameter(AstNodeRef::new(parsed, parameter))
Expand Down Expand Up @@ -262,7 +266,7 @@ impl DefinitionNodeRef<'_> {
iterable: _,
target,
}) => target.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(),
Self::Parameter(node) => match node {
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
Expand Down Expand Up @@ -313,13 +317,18 @@ impl MatchPatternDefinitionKind {

#[derive(Clone, Debug)]
pub struct ComprehensionDefinitionKind {
node: AstNodeRef<ast::Comprehension>,
iterable: AstNodeRef<ast::Expr>,
target: AstNodeRef<ast::ExprName>,
first: bool,
}

impl ComprehensionDefinitionKind {
pub(crate) fn node(&self) -> &ast::Comprehension {
self.node.node()
pub(crate) fn iterable(&self) -> &ast::Expr {
self.iterable.node()
}

pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
}

pub(crate) fn is_first(&self) -> bool {
Expand Down Expand Up @@ -442,12 +451,6 @@ impl From<&ast::StmtFor> for DefinitionNodeKey {
}
}

impl From<&ast::Comprehension> for DefinitionNodeKey {
fn from(node: &ast::Comprehension) -> Self {
Self(NodeKey::from_node(node))
}
}

impl From<&ast::Parameter> for DefinitionNodeKey {
fn from(node: &ast::Parameter) -> Self {
Self(NodeKey::from_node(node))
Expand Down
84 changes: 55 additions & 29 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ impl<'db> TypeInferenceBuilder<'db> {
}
DefinitionKind::Comprehension(comprehension) => {
self.infer_comprehension_definition(
comprehension.node(),
comprehension.iterable(),
comprehension.target(),
comprehension.is_first(),
definition,
);
Expand Down Expand Up @@ -1545,11 +1546,11 @@ impl<'db> TypeInferenceBuilder<'db> {

/// Infer the type of the `iter` expression of the first comprehension.
fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) {
let mut generators_iter = comprehensions.iter();
let Some(first_generator) = generators_iter.next() else {
let mut comprehensions_iter = comprehensions.iter();
let Some(first_comprehension) = comprehensions_iter.next() else {
unreachable!("Comprehension must contain at least one generator");
};
self.infer_expression(&first_generator.iter);
self.infer_expression(&first_comprehension.iter);
}

fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> {
Expand Down Expand Up @@ -1615,9 +1616,7 @@ impl<'db> TypeInferenceBuilder<'db> {
} = generator;

self.infer_expression(elt);
for comprehension in generators {
self.infer_comprehension(comprehension);
}
self.infer_comprehensions(generators);
}

fn infer_list_comprehension_expression_scope(&mut self, listcomp: &ast::ExprListComp) {
Expand All @@ -1628,9 +1627,7 @@ impl<'db> TypeInferenceBuilder<'db> {
} = listcomp;

self.infer_expression(elt);
for comprehension in generators {
self.infer_comprehension(comprehension);
}
self.infer_comprehensions(generators);
}

fn infer_dict_comprehension_expression_scope(&mut self, dictcomp: &ast::ExprDictComp) {
Expand All @@ -1643,9 +1640,7 @@ impl<'db> TypeInferenceBuilder<'db> {

self.infer_expression(key);
self.infer_expression(value);
for comprehension in generators {
self.infer_comprehension(comprehension);
}
self.infer_comprehensions(generators);
}

fn infer_set_comprehension_expression_scope(&mut self, setcomp: &ast::ExprSetComp) {
Expand All @@ -1656,37 +1651,68 @@ impl<'db> TypeInferenceBuilder<'db> {
} = setcomp;

self.infer_expression(elt);
for comprehension in generators {
self.infer_comprehension(comprehension);
}
self.infer_comprehensions(generators);
}

fn infer_comprehension(&mut self, comprehension: &ast::Comprehension) {
self.infer_definition(comprehension);
for expr in &comprehension.ifs {
self.infer_expression(expr);
fn infer_comprehensions(&mut self, comprehensions: &[ast::Comprehension]) {
let mut comprehensions_iter = comprehensions.iter();
let Some(first_comprehension) = comprehensions_iter.next() else {
unreachable!("Comprehension must contain at least one generator");
};
self.infer_comprehension(first_comprehension, true);
for comprehension in comprehensions_iter {
self.infer_comprehension(comprehension, false);
}
}

fn infer_comprehension_definition(
&mut self,
comprehension: &ast::Comprehension,
is_first: bool,
definition: Definition<'db>,
) {
fn infer_comprehension(&mut self, comprehension: &ast::Comprehension, is_first: bool) {
let ast::Comprehension {
range: _,
target,
iter,
ifs: _,
ifs,
is_async: _,
} = comprehension;

if !is_first {
self.infer_expression(iter);
}
// TODO(dhruvmanila): The target type should be inferred based on the iter type instead.
let target_ty = self.infer_expression(target);
// TODO more complex assignment targets
if let ast::Expr::Name(name) = target {
self.infer_definition(name);
} else {
self.infer_expression(target);
}
for expr in ifs {
self.infer_expression(expr);
}
}

fn infer_comprehension_definition(
&mut self,
iterable: &ast::Expr,
target: &ast::ExprName,
is_first: bool,
definition: Definition<'db>,
) {
if !is_first {
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db, expression);
self.extend(result);
let _iterable_ty = self
.types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));
}
// TODO(dhruvmanila): The iter type for the first comprehension is coming from the
// enclosing scope.

// TODO(dhruvmanila): The target type should be inferred based on the iter type instead,
// similar to how it's done in `infer_for_statement_definition`.
let target_ty = Type::Unknown;

self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope), target_ty);
self.types.definitions.insert(definition, target_ty);
}

Expand Down

0 comments on commit e1e9143

Please sign in to comment.