Skip to content

Commit

Permalink
[red-knot] Add definition for with items (#12920)
Browse files Browse the repository at this point in the history
## Summary

This PR adds symbols and definitions introduced by `with` statements.

The symbols and definitions are introduced for each with item. The type
inference is updated to call the definition region type inference
instead.

## Test Plan

Add test case to check for symbol table and definitions.
  • Loading branch information
dhruvmanila committed Aug 22, 2024
1 parent dce87c2 commit 8144a11
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 3 deletions.
50 changes: 50 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,56 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
assert_eq!(names(&inner_comprehension_symbol_table), vec!["x"]);
}

#[test]
fn with_item_definition() {
let TestCase { db, file } = test_case(
"
with item1 as x, item2 as y:
pass
",
);

let index = semantic_index(&db, file);
let global_table = index.symbol_table(FileScopeId::global());

assert_eq!(names(&global_table), vec!["item1", "x", "item2", "y"]);

let use_def = index.use_def_map(FileScopeId::global());
for name in ["x", "y"] {
let Some(definition) = use_def.first_public_definition(
global_table.symbol_id_by_name(name).expect("symbol exists"),
) else {
panic!("Expected with item definition for {name}");
};
assert!(matches!(definition.node(&db), DefinitionKind::WithItem(_)));
}
}

#[test]
fn with_item_unpacked_definition() {
let TestCase { db, file } = test_case(
"
with context() as (x, y):
pass
",
);

let index = semantic_index(&db, file);
let global_table = index.symbol_table(FileScopeId::global());

assert_eq!(names(&global_table), vec!["context", "x", "y"]);

let use_def = index.use_def_map(FileScopeId::global());
for name in ["x", "y"] {
let Some(definition) = use_def.first_public_definition(
global_table.symbol_id_by_name(name).expect("symbol exists"),
) else {
panic!("Expected with item definition for {name}");
};
assert!(matches!(definition.node(&db), DefinitionKind::WithItem(_)));
}
}

#[test]
fn dupes() {
let TestCase { db, file } = test_case(
Expand Down
30 changes: 30 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
use crate::semantic_index::SemanticIndex;
use crate::Db;

use super::definition::WithItemDefinitionNodeRef;

pub(super) struct SemanticIndexBuilder<'db> {
// Builder state
db: &'db dyn Db,
Expand Down Expand Up @@ -561,6 +563,18 @@ where
self.flow_merge(break_state);
}
}
ast::Stmt::With(ast::StmtWith { items, body, .. }) => {
for item in items {
self.visit_expr(&item.context_expr);
if let Some(optional_vars) = item.optional_vars.as_deref() {
self.add_standalone_expression(&item.context_expr);
self.current_assignment = Some(item.into());
self.visit_expr(optional_vars);
self.current_assignment = None;
}
}
self.visit_body(body);
}
ast::Stmt::Break(_) => {
self.loop_break_states.push(self.flow_snapshot());
}
Expand Down Expand Up @@ -622,6 +636,15 @@ where
ComprehensionDefinitionNodeRef { node, first },
);
}
Some(CurrentAssignment::WithItem(with_item)) => {
self.add_definition(
symbol,
WithItemDefinitionNodeRef {
node: with_item,
target: name_node,
},
);
}
None => {}
}
}
Expand Down Expand Up @@ -778,6 +801,7 @@ enum CurrentAssignment<'a> {
node: &'a ast::Comprehension,
first: bool,
},
WithItem(&'a ast::WithItem),
}

impl<'a> From<&'a ast::StmtAssign> for CurrentAssignment<'a> {
Expand All @@ -803,3 +827,9 @@ impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
Self::Named(value)
}
}

impl<'a> From<&'a ast::WithItem> for CurrentAssignment<'a> {
fn from(value: &'a ast::WithItem) -> Self {
Self::WithItem(value)
}
}
38 changes: 37 additions & 1 deletion crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub(crate) enum DefinitionNodeRef<'a> {
AugmentedAssignment(&'a ast::StmtAugAssign),
Comprehension(ComprehensionDefinitionNodeRef<'a>),
Parameter(ast::AnyParameterRef<'a>),
WithItem(WithItemDefinitionNodeRef<'a>),
}

impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
Expand Down Expand Up @@ -97,6 +98,12 @@ impl<'a> From<AssignmentDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
}
}

impl<'a> From<WithItemDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node_ref: WithItemDefinitionNodeRef<'a>) -> Self {
Self::WithItem(node_ref)
}
}

impl<'a> From<ComprehensionDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node: ComprehensionDefinitionNodeRef<'a>) -> Self {
Self::Comprehension(node)
Expand All @@ -121,6 +128,12 @@ pub(crate) struct AssignmentDefinitionNodeRef<'a> {
pub(crate) target: &'a ast::ExprName,
}

#[derive(Copy, Clone, Debug)]
pub(crate) struct WithItemDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::WithItem,
pub(crate) target: &'a ast::ExprName,
}

#[derive(Copy, Clone, Debug)]
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::Comprehension,
Expand Down Expand Up @@ -175,6 +188,12 @@ impl DefinitionNodeRef<'_> {
DefinitionKind::ParameterWithDefault(AstNodeRef::new(parsed, parameter))
}
},
DefinitionNodeRef::WithItem(WithItemDefinitionNodeRef { node, target }) => {
DefinitionKind::WithItem(WithItemDefinitionKind {
node: AstNodeRef::new(parsed.clone(), node),
target: AstNodeRef::new(parsed, target),
})
}
}
}

Expand All @@ -198,6 +217,7 @@ impl DefinitionNodeRef<'_> {
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
},
Self::WithItem(WithItemDefinitionNodeRef { node: _, target }) => target.into(),
}
}
}
Expand All @@ -215,6 +235,7 @@ pub enum DefinitionKind {
Comprehension(ComprehensionDefinitionKind),
Parameter(AstNodeRef<ast::Parameter>),
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
WithItem(WithItemDefinitionKind),
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -250,7 +271,6 @@ impl ImportFromDefinitionKind {
}

#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct AssignmentDefinitionKind {
assignment: AstNodeRef<ast::StmtAssign>,
target: AstNodeRef<ast::ExprName>,
Expand All @@ -266,6 +286,22 @@ impl AssignmentDefinitionKind {
}
}

#[derive(Clone, Debug)]
pub struct WithItemDefinitionKind {
node: AstNodeRef<ast::WithItem>,
target: AstNodeRef<ast::ExprName>,
}

impl WithItemDefinitionKind {
pub(crate) fn node(&self) -> &ast::WithItem {
self.node.node()
}

pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
}
}

#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub(crate) struct DefinitionNodeKey(NodeKey);

Expand Down
36 changes: 34 additions & 2 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ impl<'db> TypeInferenceBuilder<'db> {
DefinitionKind::ParameterWithDefault(parameter_with_default) => {
self.infer_parameter_with_default_definition(parameter_with_default, definition);
}
DefinitionKind::WithItem(with_item) => {
self.infer_with_item_definition(with_item.target(), with_item.node(), definition);
}
}
}

Expand Down Expand Up @@ -618,13 +621,42 @@ impl<'db> TypeInferenceBuilder<'db> {
} = with_statement;

for item in items {
self.infer_expression(&item.context_expr);
self.infer_optional_expression(item.optional_vars.as_deref());
match item.optional_vars.as_deref() {
Some(ast::Expr::Name(name)) => {
self.infer_definition(name);
}
_ => {
// TODO infer definitions in unpacking assignment
self.infer_expression(&item.context_expr);
}
}
}

self.infer_body(body);
}

fn infer_with_item_definition(
&mut self,
target: &ast::ExprName,
with_item: &ast::WithItem,
definition: Definition<'db>,
) {
let expression = self.index.expression(&with_item.context_expr);
let result = infer_expression_types(self.db, expression);
self.extend(result);

// TODO(dhruvmanila): The correct type inference here is the return type of the __enter__
// method of the context manager.
let context_expr_ty = self
.types
.expression_ty(with_item.context_expr.scoped_ast_id(self.db, self.scope));

self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope), context_expr_ty);
self.types.definitions.insert(definition, context_expr_ty);
}

fn infer_match_statement(&mut self, match_statement: &ast::StmtMatch) {
let ast::StmtMatch {
range: _,
Expand Down

0 comments on commit 8144a11

Please sign in to comment.