diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs index 182bb5f79cf9d..2406b414ed356 100644 --- a/crates/red_knot/src/symbols.rs +++ b/crates/red_knot/src/symbols.rs @@ -68,7 +68,7 @@ pub(crate) struct Scope { name: Name, kind: ScopeKind, child_scopes: Vec, - // symbol IDs, hashed by symbol name + /// symbol IDs, hashed by symbol name symbols_by_name: Map, } @@ -107,6 +107,7 @@ bitflags! { pub(crate) struct Symbol { name: Name, flags: SymbolFlags, + scope_id: ScopeId, // kind: Kind, } @@ -141,7 +142,7 @@ pub(crate) enum Definition { // the small amount of information we need from the AST. Import(ImportDefinition), ImportFrom(ImportFromDefinition), - ClassDef(TypedNodeKey), + ClassDef(ClassDefinition), FunctionDef(TypedNodeKey), Assignment(TypedNodeKey), AnnotatedAssignment(TypedNodeKey), @@ -174,6 +175,12 @@ impl ImportFromDefinition { } } +#[derive(Clone, Debug)] +pub(crate) struct ClassDefinition { + pub(crate) node_key: TypedNodeKey, + pub(crate) scope_id: ScopeId, +} + #[derive(Debug, Clone)] pub enum Dependency { Module(ModuleName), @@ -332,7 +339,11 @@ impl SymbolTable { *entry.key() } RawEntryMut::Vacant(entry) => { - let id = self.symbols_by_id.push(Symbol { name, flags }); + let id = self.symbols_by_id.push(Symbol { + name, + flags, + scope_id, + }); entry.insert_with_hasher(hash, id, (), |_| hash); id } @@ -459,8 +470,8 @@ impl SymbolTableBuilder { symbol_id } - fn push_scope(&mut self, child_of: ScopeId, name: &str, kind: ScopeKind) -> ScopeId { - let scope_id = self.table.add_child_scope(child_of, name, kind); + fn push_scope(&mut self, name: &str, kind: ScopeKind) -> ScopeId { + let scope_id = self.table.add_child_scope(self.cur_scope(), name, kind); self.scopes.push(scope_id); scope_id } @@ -482,10 +493,10 @@ impl SymbolTableBuilder { &mut self, name: &str, params: &Option>, - nested: impl FnOnce(&mut Self), - ) { + nested: impl FnOnce(&mut Self) -> ScopeId, + ) -> ScopeId { if let Some(type_params) = params { - self.push_scope(self.cur_scope(), name, ScopeKind::Annotation); + self.push_scope(name, ScopeKind::Annotation); for type_param in &type_params.type_params { let name = match type_param { ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name, @@ -495,10 +506,11 @@ impl SymbolTableBuilder { self.add_or_update_symbol(name, SymbolFlags::IS_DEFINED); } } - nested(self); + let scope_id = nested(self); if params.is_some() { self.pop_scope(); } + scope_id } } @@ -525,21 +537,26 @@ impl PreorderVisitor<'_> for SymbolTableBuilder { // TODO need to capture more definition statements here match stmt { ast::Stmt::ClassDef(node) => { - let def = Definition::ClassDef(TypedNodeKey::from_node(node)); - self.add_or_update_symbol_with_def(&node.name, def); - self.with_type_params(&node.name, &node.type_params, |builder| { - builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Class); + let scope_id = self.with_type_params(&node.name, &node.type_params, |builder| { + let scope_id = builder.push_scope(&node.name, ScopeKind::Class); ast::visitor::preorder::walk_stmt(builder, stmt); builder.pop_scope(); + scope_id }); + let def = Definition::ClassDef(ClassDefinition { + node_key: TypedNodeKey::from_node(node), + scope_id, + }); + self.add_or_update_symbol_with_def(&node.name, def); } ast::Stmt::FunctionDef(node) => { let def = Definition::FunctionDef(TypedNodeKey::from_node(node)); self.add_or_update_symbol_with_def(&node.name, def); self.with_type_params(&node.name, &node.type_params, |builder| { - builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Function); + let scope_id = builder.push_scope(&node.name, ScopeKind::Function); ast::visitor::preorder::walk_stmt(builder, stmt); builder.pop_scope(); + scope_id }); } ast::Stmt::Import(ast::StmtImport { names, .. }) => { diff --git a/crates/red_knot/src/types.rs b/crates/red_knot/src/types.rs index bce8eff114148..871c43092bb4c 100644 --- a/crates/red_knot/src/types.rs +++ b/crates/red_knot/src/types.rs @@ -1,7 +1,8 @@ #![allow(dead_code)] use crate::ast_ids::NodeKey; +use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; -use crate::symbols::SymbolId; +use crate::symbols::{ScopeId, SymbolId}; use crate::{FxDashMap, FxIndexSet, Name}; use ruff_index::{newtype_index, IndexVec}; use rustc_hash::FxHashMap; @@ -124,8 +125,15 @@ impl TypeStore { .add_function(name, decorators) } - fn add_class(&self, file_id: FileId, name: &str, bases: Vec) -> ClassTypeId { - self.add_or_get_module(file_id).add_class(name, bases) + fn add_class( + &self, + file_id: FileId, + name: &str, + scope_id: ScopeId, + bases: Vec, + ) -> ClassTypeId { + self.add_or_get_module(file_id) + .add_class(name, scope_id, bases) } fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId { @@ -253,6 +261,24 @@ pub struct ClassTypeId { class_id: ModuleClassTypeId, } +impl ClassTypeId { + fn get_own_class_member(self, db: &Db, name: &Name) -> QueryResult> + where + Db: SemanticDb + HasJar, + { + // TODO: this should distinguish instance-only members (e.g. `x: int`) and not return them + let ClassType { scope_id, .. } = *db.jar()?.type_store.get_class(self); + let table = db.symbol_table(self.file_id)?; + if let Some(symbol_id) = table.symbol_id_by_name(scope_id, name) { + Ok(Some(db.infer_symbol_type(self.file_id, symbol_id)?)) + } else { + Ok(None) + } + } + + // TODO: get_own_instance_member, get_class_member, get_instance_member +} + #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] pub struct UnionTypeId { file_id: FileId, @@ -318,9 +344,10 @@ impl ModuleTypeStore { } } - fn add_class(&mut self, name: &str, bases: Vec) -> ClassTypeId { + fn add_class(&mut self, name: &str, scope_id: ScopeId, bases: Vec) -> ClassTypeId { let class_id = self.classes.push(ClassType { name: Name::new(name), + scope_id, // TODO: if no bases are given, that should imply [object] bases, }); @@ -405,7 +432,11 @@ impl std::fmt::Display for DisplayType<'_> { #[derive(Debug)] pub(crate) struct ClassType { + /// Name of the class at definition name: Name, + /// `ScopeId` of the class body + pub(crate) scope_id: ScopeId, + /// Types of all class bases bases: Vec, } @@ -496,6 +527,7 @@ impl IntersectionType { #[cfg(test)] mod tests { use crate::files::Files; + use crate::symbols::SymbolTable; use crate::types::{Type, TypeStore}; use crate::FxIndexSet; use std::path::Path; @@ -505,7 +537,7 @@ mod tests { let store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); - let id = store.add_class(file_id, "C", Vec::new()); + let id = store.add_class(file_id, "C", SymbolTable::root_scope_id(), Vec::new()); assert_eq!(store.get_class(id).name(), "C"); let inst = Type::Instance(id); assert_eq!(format!("{}", inst.display(&store)), "C"); @@ -528,8 +560,8 @@ mod tests { let mut store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); - let c1 = store.add_class(file_id, "C1", Vec::new()); - let c2 = store.add_class(file_id, "C2", Vec::new()); + let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new()); + let c2 = store.add_class(file_id, "C2", SymbolTable::root_scope_id(), Vec::new()); let elems = vec![Type::Instance(c1), Type::Instance(c2)]; let id = store.add_union(file_id, &elems); assert_eq!( @@ -545,9 +577,9 @@ mod tests { let mut store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); - let c1 = store.add_class(file_id, "C1", Vec::new()); - let c2 = store.add_class(file_id, "C2", Vec::new()); - let c3 = store.add_class(file_id, "C3", Vec::new()); + let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new()); + let c2 = store.add_class(file_id, "C2", SymbolTable::root_scope_id(), Vec::new()); + let c3 = store.add_class(file_id, "C3", SymbolTable::root_scope_id(), Vec::new()); let pos = vec![Type::Instance(c1), Type::Instance(c2)]; let neg = vec![Type::Instance(c3)]; let id = store.add_intersection(file_id, &pos, &neg); diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs index f7e890fbb1eb3..efdd3a484cb6a 100644 --- a/crates/red_knot/src/types/infer.rs +++ b/crates/red_knot/src/types/infer.rs @@ -4,7 +4,7 @@ use ruff_python_ast::AstNode; use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar}; use crate::module::ModuleName; -use crate::symbols::{Definition, ImportFromDefinition, SymbolId}; +use crate::symbols::{ClassDefinition, Definition, ImportFromDefinition, SymbolId}; use crate::types::Type; use crate::FileId; use ruff_python_ast as ast; @@ -51,7 +51,7 @@ where Type::Unknown } } - Definition::ClassDef(node_key) => { + Definition::ClassDef(ClassDefinition { node_key, scope_id }) => { if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) { ty } else { @@ -65,7 +65,8 @@ where bases.push(infer_expr_type(db, file_id, base)?); } - let ty = Type::Class(type_store.add_class(file_id, &node.name.id, bases)); + let ty = + Type::Class(type_store.add_class(file_id, &node.name.id, *scope_id, bases)); type_store.cache_node_type(file_id, *node_key.erased(), ty); ty } @@ -133,6 +134,7 @@ mod tests { use crate::db::{HasJar, SemanticDb, SemanticJar}; use crate::module::{ModuleName, ModuleSearchPath, ModuleSearchPathKind}; use crate::types::Type; + use crate::Name; // TODO with virtual filesystem we shouldn't have to write files to disk for these // tests @@ -222,4 +224,42 @@ mod tests { Ok(()) } + + #[test] + fn resolve_method() -> anyhow::Result<()> { + let case = create_test()?; + let db = &case.db; + + let path = case.src.path().join("mod.py"); + std::fs::write(path, "class C:\n def f(self): pass")?; + let file = db + .resolve_module(ModuleName::new("mod"))? + .expect("module should be found") + .path(db)? + .file(); + let syms = db.symbol_table(file)?; + let sym = syms + .root_symbol_id_by_name("C") + .expect("C symbol should be found"); + + let ty = db.infer_symbol_type(file, sym)?; + + let Type::Class(class_id) = ty else { + panic!("C is not a Class"); + }; + + let member_ty = class_id + .get_own_class_member(db, &Name::new("f")) + .expect("C.f to resolve"); + + let Some(Type::Function(func_id)) = member_ty else { + panic!("C.f is not a Function"); + }; + + let jar = HasJar::::jar(db)?; + let function = jar.type_store.get_function(func_id); + assert_eq!(function.name(), "f"); + + Ok(()) + } }