Skip to content

Commit

Permalink
[red-knot] add type narrowing
Browse files Browse the repository at this point in the history
  • Loading branch information
carljm committed Aug 15, 2024
1 parent 73160dc commit ff18246
Show file tree
Hide file tree
Showing 9 changed files with 1,143 additions and 176 deletions.
67 changes: 43 additions & 24 deletions crates/red_knot_python_semantic/src/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@ use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{
FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId, SymbolTable,
};
use crate::semantic_index::use_def::UseDefMap;
use crate::Db;

pub(crate) use self::use_def::UseDefMap;

pub mod ast_ids;
mod builder;
pub mod definition;
pub mod expression;
pub mod symbol;
mod use_def;

pub(crate) use self::use_def::{DefinitionWithConstraints, DefinitionWithConstraintsIterator};

type SymbolMap = hashbrown::HashMap<ScopedSymbolId, (), ()>;

/// Returns the semantic index for `file`.
Expand Down Expand Up @@ -313,6 +314,7 @@ mod tests {
use crate::semantic_index::ast_ids::HasScopedUseId;
use crate::semantic_index::definition::DefinitionKind;
use crate::semantic_index::symbol::{FileScopeId, Scope, ScopeKind, SymbolTable};
use crate::semantic_index::use_def::DefinitionWithConstraints;
use crate::semantic_index::{global_scope, semantic_index, symbol_table, use_def_map};
use crate::Db;

Expand Down Expand Up @@ -374,7 +376,9 @@ mod tests {
let foo = global_table.symbol_id_by_name("foo").unwrap();

let use_def = use_def_map(&db, scope);
let [definition] = use_def.public_definitions(foo) else {
let [DefinitionWithConstraints { definition, .. }] =
use_def.public_definitions(foo).collect::<Vec<_>>()[..]
else {
panic!("expected one definition");
};
assert!(matches!(definition.node(&db), DefinitionKind::Import(_)));
Expand Down Expand Up @@ -411,11 +415,14 @@ mod tests {
);

let use_def = use_def_map(&db, scope);
let [definition] = use_def.public_definitions(
global_table
.symbol_id_by_name("foo")
.expect("symbol to exist"),
) else {
let [DefinitionWithConstraints { definition, .. }] = use_def
.public_definitions(
global_table
.symbol_id_by_name("foo")
.expect("symbol to exist"),
)
.collect::<Vec<_>>()[..]
else {
panic!("expected one definition");
};
assert!(matches!(
Expand All @@ -438,8 +445,9 @@ mod tests {
"a symbol used but not defined in a scope should have only the used flag"
);
let use_def = use_def_map(&db, scope);
let [definition] =
use_def.public_definitions(global_table.symbol_id_by_name("x").expect("symbol exists"))
let [DefinitionWithConstraints { definition, .. }] = use_def
.public_definitions(global_table.symbol_id_by_name("x").expect("symbol exists"))
.collect::<Vec<_>>()[..]
else {
panic!("expected one definition");
};
Expand Down Expand Up @@ -477,8 +485,9 @@ y = 2
assert_eq!(names(&class_table), vec!["x"]);

let use_def = index.use_def_map(class_scope_id);
let [definition] =
use_def.public_definitions(class_table.symbol_id_by_name("x").expect("symbol exists"))
let [DefinitionWithConstraints { definition, .. }] = use_def
.public_definitions(class_table.symbol_id_by_name("x").expect("symbol exists"))
.collect::<Vec<_>>()[..]
else {
panic!("expected one definition");
};
Expand Down Expand Up @@ -515,11 +524,14 @@ y = 2
assert_eq!(names(&function_table), vec!["x"]);

let use_def = index.use_def_map(function_scope_id);
let [definition] = use_def.public_definitions(
function_table
.symbol_id_by_name("x")
.expect("symbol exists"),
) else {
let [DefinitionWithConstraints { definition, .. }] = use_def
.public_definitions(
function_table
.symbol_id_by_name("x")
.expect("symbol exists"),
)
.collect::<Vec<_>>()[..]
else {
panic!("expected one definition");
};
assert!(matches!(
Expand Down Expand Up @@ -594,7 +606,9 @@ y = 2
let element_use_id =
element.scoped_use_id(&db, comprehension_scope_id.to_scope_id(&db, file));

let [definition] = use_def.use_definitions(element_use_id) else {
let [DefinitionWithConstraints { definition, .. }] =
use_def.use_definitions(element_use_id).collect::<Vec<_>>()[..]
else {
panic!("expected one definition")
};
let DefinitionKind::Comprehension(comprehension) = definition.node(&db) else {
Expand Down Expand Up @@ -693,11 +707,14 @@ def func():
assert_eq!(names(&func2_table), vec!["y"]);

let use_def = index.use_def_map(FileScopeId::global());
let [definition] = use_def.public_definitions(
global_table
.symbol_id_by_name("func")
.expect("symbol exists"),
) else {
let [DefinitionWithConstraints { definition, .. }] = use_def
.public_definitions(
global_table
.symbol_id_by_name("func")
.expect("symbol exists"),
)
.collect::<Vec<_>>()[..]
else {
panic!("expected one definition");
};
assert!(matches!(definition.node(&db), DefinitionKind::Function(_)));
Expand Down Expand Up @@ -800,7 +817,9 @@ class C[T]:
};
let x_use_id = x_use_expr_name.scoped_use_id(&db, scope);
let use_def = use_def_map(&db, scope);
let [definition] = use_def.use_definitions(x_use_id) else {
let [DefinitionWithConstraints { definition, .. }] =
use_def.use_definitions(x_use_id).collect::<Vec<_>>()[..]
else {
panic!("expected one definition");
};
let DefinitionKind::Assignment(assignment) = definition.node(&db) else {
Expand Down
11 changes: 10 additions & 1 deletion crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,16 @@ impl<'db> SemanticIndexBuilder<'db> {
definition
}

fn add_constraint(&mut self, constraint_node: &ast::Expr) -> Expression<'db> {
let expression = self.add_standalone_expression(constraint_node);
self.current_use_def_map_mut().record_constraint(expression);

expression
}

/// Record an expression that needs to be a Salsa ingredient, because we need to infer its type
/// standalone (type narrowing tests, RHS of an assignment.)
fn add_standalone_expression(&mut self, expression_node: &ast::Expr) {
fn add_standalone_expression(&mut self, expression_node: &ast::Expr) -> Expression<'db> {
let expression = Expression::new(
self.db,
self.file,
Expand All @@ -210,6 +217,7 @@ impl<'db> SemanticIndexBuilder<'db> {
);
self.expressions_by_node
.insert(expression_node.into(), expression);
expression
}

fn with_type_params(
Expand Down Expand Up @@ -456,6 +464,7 @@ where
ast::Stmt::If(node) => {
self.visit_expr(&node.test);
let pre_if = self.flow_snapshot();
self.add_constraint(&node.test);
self.visit_body(&node.body);
let mut post_clauses: Vec<FlowSnapshot> = vec![];
for clause in &node.elif_else_clauses {
Expand Down
Loading

0 comments on commit ff18246

Please sign in to comment.