Skip to content

Commit

Permalink
[red-knot] Add type inference for basic for loops
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed Sep 1, 2024
1 parent 2014cba commit b46aecf
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 55 deletions.
16 changes: 0 additions & 16 deletions crates/red_knot_python_semantic/src/builtins.rs

This file was deleted.

38 changes: 38 additions & 0 deletions crates/red_knot_python_semantic/src/core_stdlib_modules.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use crate::module_name::ModuleName;
use crate::module_resolver::resolve_module;
use crate::semantic_index::global_scope;
use crate::semantic_index::symbol::ScopeId;
use crate::Db;

/// Salsa query to get the builtins scope.
///
/// Can return None if a custom typeshed is used that is missing `builtins.pyi`.
#[salsa::tracked]
pub(crate) fn builtins_scope(db: &dyn Db) -> Option<ScopeId<'_>> {
let builtins_name =
ModuleName::new_static("builtins").expect("Expected 'builtins' to be a valid module name");
let builtins_file = resolve_module(db, builtins_name)?.file();
Some(global_scope(db, builtins_file))
}

/// Salsa query to get the scope for the `types` module.
///
/// Can return None if a custom typeshed is used that is missing `types.pyi`.
#[salsa::tracked]
pub(crate) fn types_scope(db: &dyn Db) -> Option<ScopeId<'_>> {
let types_module_name =
ModuleName::new_static("types").expect("Expected 'types' to be a valid module name");
let types_file = resolve_module(db, types_module_name)?.file();
Some(global_scope(db, types_file))
}

/// Salsa query to get the scope for the `_typeshed` module.
///
/// Can return None if a custom typeshed is used that is missing a `_typeshed` directory.
#[salsa::tracked]
pub(crate) fn typeshed_scope(db: &dyn Db) -> Option<ScopeId<'_>> {
let typeshed_module_name = ModuleName::new_static("_typeshed")
.expect("Expected '_typeshed' to be a valid module name");
let typeshed_file = resolve_module(db, typeshed_module_name)?.file();
Some(global_scope(db, typeshed_file))
}
2 changes: 1 addition & 1 deletion crates/red_knot_python_semantic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub use python_version::PythonVersion;
pub use semantic_model::{HasTy, SemanticModel};

pub mod ast_node_ref;
mod builtins;
mod core_stdlib_modules;
mod db;
mod module_name;
mod module_resolver;
Expand Down
183 changes: 163 additions & 20 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use ruff_db::files::File;
use ruff_python_ast as ast;

use crate::builtins::builtins_scope;
use crate::core_stdlib_modules::{builtins_scope, types_scope, typeshed_scope};
use crate::semantic_index::ast_ids::HasScopedAstId;
use crate::semantic_index::definition::{Definition, DefinitionKind};
use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId};
Expand Down Expand Up @@ -75,13 +75,44 @@ pub(crate) fn global_symbol_ty_by_name<'db>(db: &'db dyn Db, file: File, name: &
symbol_ty_by_name(db, global_scope(db, file), name)
}

/// Shorthand for `symbol_ty` that looks up a symbol in the builtins.
/// Enumeration of various core stdlib modules, for which we have dedicated Salsa queries.
///
/// Returns `Unbound` if the builtins module isn't available for some reason.
/// Things will start getting very strange during type-checking if one of these doesn't exist.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CoreStdlibModule {
Builtins,
Types,
Typeshed,
}

impl CoreStdlibModule {
/// Retrieve the global scope of the given module.
///
/// Returns `None` if the given module isn't available for some reason.
pub(crate) fn global_scope(self, db: &dyn Db) -> Option<ScopeId<'_>> {
match self {
Self::Builtins => builtins_scope(db),
Self::Types => types_scope(db),
Self::Typeshed => typeshed_scope(db),
}
}

/// Shorthand for `symbol_ty` that looks up a symbol in the scope of a given core module.
///
/// Returns `Unbound` if the given module isn't available for some reason.
pub(crate) fn symbol_ty_by_name<'db>(self, db: &'db dyn Db, name: &str) -> Type<'db> {
self.global_scope(db)
.map(|globals| symbol_ty_by_name(db, globals, name))
.unwrap_or(Type::Unbound)
}
}

/// Shorthand for `symbol_ty` that looks up a symbol in the `buitlins` scope.
///
/// Returns `Unbound` if the `builtins` module isn't available for some reason.
#[inline]
pub(crate) fn builtins_symbol_ty_by_name<'db>(db: &'db dyn Db, name: &str) -> Type<'db> {
builtins_scope(db)
.map(|builtins| symbol_ty_by_name(db, builtins, name))
.unwrap_or(Type::Unbound)
CoreStdlibModule::Builtins.symbol_ty_by_name(db, name)
}

/// Infer the type of a [`Definition`].
Expand Down Expand Up @@ -238,13 +269,9 @@ impl<'db> Type<'db> {
pub fn replace_unbound_with(&self, db: &'db dyn Db, replacement: Type<'db>) -> Type<'db> {
match self {
Type::Unbound => replacement,
Type::Union(union) => union
.elements(db)
.into_iter()
.fold(UnionBuilder::new(db), |builder, ty| {
builder.add(ty.replace_unbound_with(db, replacement))
})
.build(),
Type::Union(union) => {
union.transform(db, |ty| ty.replace_unbound_with(db, replacement))
}
ty => *ty,
}
}
Expand Down Expand Up @@ -286,13 +313,7 @@ impl<'db> Type<'db> {
// TODO MRO? get_own_instance_member, get_instance_member
Type::Unknown
}
Type::Union(union) => union
.elements(db)
.iter()
.fold(UnionBuilder::new(db), |builder, element_ty| {
builder.add(element_ty.member(db, name))
})
.build(),
Type::Union(union) => union.transform(db, |ty| ty.member(db, name)),
Type::Intersection(_) => {
// TODO perform the get_member on each type in the intersection
// TODO return the intersection of those results
Expand Down Expand Up @@ -347,6 +368,36 @@ impl<'db> Type<'db> {
}
}

/// Given the type of an object that is iterated over in some way,
/// return the type of objects that are yielded by that iteration.
///
/// E.g., for the following loop, given the type of `x`, infer the type of `y`:
/// ```python
/// for y in x:
/// pass
/// ```
///
/// Returns `None` if `self` represents a type that is not iterable.
fn iterate(&self, db: &'db dyn Db) -> Option<Type<'db>> {
// `self` represents the type of the iterable;
// `__iter__` and `__next__` are both looked up on the class of the iterable:
let type_of_class = self.to_type_of_class(db);
let dunder_iter_method = type_of_class.member(db, &ast::name::Name::from("__iter__"));
if !dunder_iter_method.is_unbound() {
let iterator_ty = dunder_iter_method.call(db)?;
let dunder_next_method = iterator_ty
.to_type_of_class(db)
.member(db, &ast::name::Name::from("__next__"));
return dunder_next_method.call(db);
}
// Although it's not considered great practice,
// classes that define `__getitem__` are also iterable,
// even if they do not define `__iter__`:
let dunder_get_item_method =
type_of_class.member(db, &ast::name::Name::from("__getitem__"));
dunder_get_item_method.call(db)
}

#[must_use]
pub fn instance(&self) -> Type<'db> {
match self {
Expand All @@ -356,6 +407,34 @@ impl<'db> Type<'db> {
_ => Type::Unknown, // TODO type errors
}
}

/// Given a type that is assumed to represent an instance of a class,
/// return a type that represents that class itself.
#[must_use]
pub fn to_type_of_class(&self, db: &'db dyn Db) -> Type<'db> {
match self {
Type::Unbound => Type::Unbound,
Type::Never => Type::Never,
Type::Instance(class) => Type::Class(*class),
Type::Union(union) => union.transform(db, |ty| ty.to_type_of_class(db)),
Type::BooleanLiteral(_) => builtins_symbol_ty_by_name(db, "bool"),
Type::BytesLiteral(_) => builtins_symbol_ty_by_name(db, "bytes"),
Type::IntLiteral(_) => builtins_symbol_ty_by_name(db, "int"),
Type::Function(_) => CoreStdlibModule::Types.symbol_ty_by_name(db, "FunctionType"),
Type::Module(_) => CoreStdlibModule::Types.symbol_ty_by_name(db, "ModuleType"),
Type::None => CoreStdlibModule::Typeshed.symbol_ty_by_name(db, "NoneType"),
// TODO not accurate if there's a custom metaclass...
Type::Class(_) => builtins_symbol_ty_by_name(db, "type"),
// TODO can we do better here? `type[LiteralString]`?
Type::StringLiteral(_) | Type::LiteralString => builtins_symbol_ty_by_name(db, "str"),
// TODO: `type[Any]`?
Type::Any => Type::Any,
// TODO: `type[Unknown]`?
Type::Unknown => Type::Unknown,
// TODO intersections
Type::Intersection(_) => Type::Unknown,
}
}
}

#[salsa::interned]
Expand Down Expand Up @@ -471,6 +550,21 @@ impl<'db> UnionType<'db> {
pub fn contains(&self, db: &'db dyn Db, ty: Type<'db>) -> bool {
self.elements(db).contains(&ty)
}

/// Apply a transformation function to all elements of the union,
/// and create a new union from the resulting set of types
pub fn transform(
&self,
db: &'db dyn Db,
mut transform_fn: impl FnMut(&Type<'db>) -> Type<'db>,
) -> Type<'db> {
self.elements(db)
.into_iter()
.fold(UnionBuilder::new(db), |builder, ty| {
builder.add(transform_fn(&ty))
})
.build()
}
}

#[salsa::interned]
Expand Down Expand Up @@ -615,4 +709,53 @@ mod tests {
&["Object of type 'Literal[123]' is not callable"],
);
}

#[test]
fn invalid_iterable() {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
nonsense = 123
for x in nonsense:
pass
",
)
.unwrap();

let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'Literal[123]' is not iterable"],
);
}

#[test]
fn new_iteration_protocol_takes_precedence_over_old_style() {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
class NotIterable:
def __getitem__(self, key: int) -> int:
return 42
__iter__ = None
for x in NotIterable():
pass
",
)
.unwrap();

let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'NotIterable' is not iterable"],
);
}
}
4 changes: 1 addition & 3 deletions crates/red_knot_python_semantic/src/types/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@
//! * No type in an intersection can be a supertype of any other type in the intersection (just
//! eliminate the supertype from the intersection).
//! * An intersection containing two non-overlapping types should simplify to [`Type::Never`].
use crate::types::{IntersectionType, Type, UnionType};
use crate::types::{builtins_symbol_ty_by_name, IntersectionType, Type, UnionType};
use crate::{Db, FxOrderSet};

use super::builtins_symbol_ty_by_name;

pub(crate) struct UnionBuilder<'db> {
elements: FxOrderSet<Type<'db>>,
db: &'db dyn Db,
Expand Down
Loading

0 comments on commit b46aecf

Please sign in to comment.