diff --git a/crates/red_knot_python_semantic/src/module_name.rs b/crates/red_knot_python_semantic/src/module_name.rs index 3aa280fea128e..885c6adf9f4bf 100644 --- a/crates/red_knot_python_semantic/src/module_name.rs +++ b/crates/red_knot_python_semantic/src/module_name.rs @@ -168,6 +168,24 @@ impl ModuleName { }; Some(Self(name)) } + + /// Extend `self` with the components of `other` + /// + /// # Examples + /// + /// ``` + /// use red_knot_python_semantic::ModuleName; + /// + /// let mut module_name = ModuleName::new_static("foo").unwrap(); + /// module_name.extend(&ModuleName::new_static("bar").unwrap()); + /// assert_eq!(&module_name, "foo.bar"); + /// module_name.extend(&ModuleName::new_static("baz.eggs.ham").unwrap()); + /// assert_eq!(&module_name, "foo.bar.baz.eggs.ham"); + /// ``` + pub fn extend(&mut self, other: &ModuleName) { + self.0.push('.'); + self.0.push_str(other); + } } impl Deref for ModuleName { diff --git a/crates/red_knot_python_semantic/src/module_resolver/mod.rs b/crates/red_knot_python_semantic/src/module_resolver/mod.rs index 06f13271f0819..93a34f7b62c65 100644 --- a/crates/red_knot_python_semantic/src/module_resolver/mod.rs +++ b/crates/red_knot_python_semantic/src/module_resolver/mod.rs @@ -2,7 +2,7 @@ use std::iter::FusedIterator; pub(crate) use module::Module; pub use resolver::resolve_module; -pub(crate) use resolver::SearchPaths; +pub(crate) use resolver::{file_to_module, SearchPaths}; use ruff_db::system::SystemPath; pub use typeshed::vendored_typeshed_stubs; diff --git a/crates/red_knot_python_semantic/src/module_resolver/module.rs b/crates/red_knot_python_semantic/src/module_resolver/module.rs index 9814dd715735b..e2c1e939572cc 100644 --- a/crates/red_knot_python_semantic/src/module_resolver/module.rs +++ b/crates/red_knot_python_semantic/src/module_resolver/module.rs @@ -77,3 +77,9 @@ pub enum ModuleKind { /// A python package (`foo/__init__.py` or `foo/__init__.pyi`) Package, } + +impl ModuleKind { + pub const fn is_package(self) -> bool { + matches!(self, ModuleKind::Package) + } +} diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 46b52ca62751d..6f494f9c6bf96 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -20,6 +20,8 @@ //! //! Inferring types at any of the three region granularities returns a [`TypeInference`], which //! holds types for every [`Definition`] and expression within the inferred region. +use std::num::NonZeroU32; + use rustc_hash::FxHashMap; use salsa; use salsa::plumbing::AsId; @@ -31,7 +33,7 @@ use ruff_python_ast::{ExprContext, TypeParams}; use crate::builtins::builtins_scope; use crate::module_name::ModuleName; -use crate::module_resolver::resolve_module; +use crate::module_resolver::{file_to_module, resolve_module}; use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId}; use crate::semantic_index::definition::{Definition, DefinitionKind, DefinitionNodeKey}; use crate::semantic_index::expression::Expression; @@ -822,7 +824,7 @@ impl<'db> TypeInferenceBuilder<'db> { asname: _, } = alias; - let module_ty = self.module_ty_from_name(name); + let module_ty = self.module_ty_from_name(ModuleName::new(name)); self.types.definitions.insert(definition, module_ty); } @@ -860,20 +862,68 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_optional_expression(cause.as_deref()); } + /// Given a `from .foo import bar` relative import, resolve the relative module + /// we're importing `bar` from into an absolute [`ModuleName`] + /// using the name of the module we're currently analyzing. + /// + /// - `level` is the number of dots at the beginning of the relative module name: + /// - `from .foo.bar import baz` => `level == 1` + /// - `from ...foo.bar import baz` => `level == 3` + /// - `tail` is the relative module name stripped of all leading dots: + /// - `from .foo import bar` => `tail == "foo"` + /// - `from ..foo.bar import baz` => `tail == "foo.bar"` + fn relative_module_name(&self, tail: Option<&str>, level: NonZeroU32) -> Option { + let Some(module) = file_to_module(self.db, self.file) else { + tracing::debug!("Failed to resolve file {:?} to a module", self.file); + return None; + }; + let mut level = level.get(); + if module.kind().is_package() { + level -= 1; + } + let mut module_name = module.name().to_owned(); + for _ in 0..level { + module_name = module_name.parent()?; + } + if let Some(tail) = tail { + if let Some(valid_tail) = ModuleName::new(tail) { + module_name.extend(&valid_tail); + } else { + tracing::debug!("Failed to resolve relative import due to invalid syntax"); + return None; + } + } + Some(module_name) + } + fn infer_import_from_definition( &mut self, import_from: &ast::StmtImportFrom, alias: &ast::Alias, definition: Definition<'db>, ) { - let ast::StmtImportFrom { module, .. } = import_from; - let module_ty = if let Some(module) = module { - self.module_ty_from_name(module) + // TODO: + // - Absolute `*` imports (`from collections import *`) + // - Relative `*` imports (`from ...foo import *`) + // - Submodule imports (`from collections import abc`, + // where `abc` is a submodule of the `collections` package) + // + // For the last item, see the currently skipped tests + // `follow_relative_import_bare_to_module()` and + // `follow_nonexistent_import_bare_to_module()`. + let ast::StmtImportFrom { module, level, .. } = import_from; + tracing::trace!("Resolving imported object {alias:?} from statement {import_from:?}"); + let module_name = if let Some(level) = NonZeroU32::new(*level) { + self.relative_module_name(module.as_deref(), level) } else { - // TODO support relative imports - Type::Unknown + let module_name = module + .as_ref() + .expect("Non-relative import should always have a non-None `module`!"); + ModuleName::new(module_name) }; + let module_ty = self.module_ty_from_name(module_name); + let ast::Alias { range: _, name, @@ -896,11 +946,10 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn module_ty_from_name(&self, name: &ast::Identifier) -> Type<'db> { - let module = ModuleName::new(&name.id).and_then(|name| resolve_module(self.db, name)); - module - .map(|module| Type::Module(module.file())) - .unwrap_or(Type::Unbound) + fn module_ty_from_name(&self, module_name: Option) -> Type<'db> { + module_name + .and_then(|module_name| resolve_module(self.db, module_name)) + .map_or(Type::Unbound, |module| Type::Module(module.file())) } fn infer_decorator(&mut self, decorator: &ast::Decorator) -> Type<'db> { @@ -1710,6 +1759,148 @@ mod tests { Ok(()) } + #[test] + fn follow_relative_import_simple() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_files([ + ("src/package/__init__.py", ""), + ("src/package/foo.py", "X = 42"), + ("src/package/bar.py", "from .foo import X"), + ])?; + + assert_public_ty(&db, "src/package/bar.py", "X", "Literal[42]"); + + Ok(()) + } + + #[test] + fn follow_nonexistent_relative_import_simple() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_files([ + ("src/package/__init__.py", ""), + ("src/package/bar.py", "from .foo import X"), + ])?; + + assert_public_ty(&db, "src/package/bar.py", "X", "Unbound"); + + Ok(()) + } + + #[test] + fn follow_relative_import_dotted() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_files([ + ("src/package/__init__.py", ""), + ("src/package/foo/bar/baz.py", "X = 42"), + ("src/package/bar.py", "from .foo.bar.baz import X"), + ])?; + + assert_public_ty(&db, "src/package/bar.py", "X", "Literal[42]"); + + Ok(()) + } + + #[test] + fn follow_relative_import_bare_to_package() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_files([ + ("src/package/__init__.py", "X = 42"), + ("src/package/bar.py", "from . import X"), + ])?; + + assert_public_ty(&db, "src/package/bar.py", "X", "Literal[42]"); + + Ok(()) + } + + #[test] + fn follow_nonexistent_relative_import_bare_to_package() -> anyhow::Result<()> { + let mut db = setup_db(); + db.write_files([("src/package/bar.py", "from . import X")])?; + assert_public_ty(&db, "src/package/bar.py", "X", "Unbound"); + Ok(()) + } + + #[ignore = "TODO: Submodule imports possibly not supported right now?"] + #[test] + fn follow_relative_import_bare_to_module() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_files([ + ("src/package/__init__.py", ""), + ("src/package/foo.py", "X = 42"), + ("src/package/bar.py", "from . import foo; y = foo.X"), + ])?; + + assert_public_ty(&db, "src/package/bar.py", "y", "Literal[42]"); + + Ok(()) + } + + #[ignore = "TODO: Submodule imports possibly not supported right now?"] + #[test] + fn follow_nonexistent_import_bare_to_module() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_files([ + ("src/package/__init__.py", ""), + ("src/package/bar.py", "from . import foo"), + ])?; + + assert_public_ty(&db, "src/package/bar.py", "foo", "Unbound"); + + Ok(()) + } + + #[test] + fn follow_relative_import_from_dunder_init() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_files([ + ("src/package/__init__.py", "from .foo import X"), + ("src/package/foo.py", "X = 42"), + ])?; + + assert_public_ty(&db, "src/package/__init__.py", "X", "Literal[42]"); + + Ok(()) + } + + #[test] + fn follow_nonexistent_relative_import_from_dunder_init() -> anyhow::Result<()> { + let mut db = setup_db(); + db.write_files([("src/package/__init__.py", "from .foo import X")])?; + assert_public_ty(&db, "src/package/__init__.py", "X", "Unbound"); + Ok(()) + } + + #[test] + fn follow_very_relative_import() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_files([ + ("src/package/__init__.py", ""), + ("src/package/foo.py", "X = 42"), + ( + "src/package/subpackage/subsubpackage/bar.py", + "from ...foo import X", + ), + ])?; + + assert_public_ty( + &db, + "src/package/subpackage/subsubpackage/bar.py", + "X", + "Literal[42]", + ); + + Ok(()) + } + #[test] fn resolve_base_class_by_name() -> anyhow::Result<()> { let mut db = setup_db(); diff --git a/crates/ruff_benchmark/benches/red_knot.rs b/crates/ruff_benchmark/benches/red_knot.rs index 727b50a452ed4..f99a0fa06cc61 100644 --- a/crates/ruff_benchmark/benches/red_knot.rs +++ b/crates/ruff_benchmark/benches/red_knot.rs @@ -89,7 +89,7 @@ fn benchmark_incremental(criterion: &mut Criterion) { let Case { db, parser, .. } = case; let result = db.check_file(*parser).unwrap(); - assert_eq!(result.len(), 111); + assert_eq!(result.len(), 29); }, BatchSize::SmallInput, ); @@ -104,7 +104,7 @@ fn benchmark_cold(criterion: &mut Criterion) { let Case { db, parser, .. } = case; let result = db.check_file(*parser).unwrap(); - assert_eq!(result.len(), 111); + assert_eq!(result.len(), 29); }, BatchSize::SmallInput, );