WIP: Jack + Codex mucking around with loop control flow. (The `unsafe` lifetime-transmuting code in `types.rs` is obviously not going to survive review.)
diff --git a/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md b/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md index 41e48b1..9d4463e 100644 --- a/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md +++ b/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md
@@ -127,3 +127,64 @@ while NotBoolable(): ... ``` + +## Backwards control flow + +```py +i = 0 +reveal_type(i) # revealed: Literal[0] +while i < 1_000_000: + reveal_type(i) # revealed: int + i += 1 + reveal_type(i) # revealed: int +reveal_type(i) # revealed: int + +# TODO: None of these should need to be raised to `int`. Loop control flow analysis should take the +# loop condition into account. +i = 0 +reveal_type(i) # revealed: Literal[0] +while i < 2: + # TODO: Should be Literal[0, 1]. + reveal_type(i) # revealed: int + i += 1 + # TODO: Should be Literal[1, 2]. + reveal_type(i) # revealed: int +# TODO: Should be Literal[2]. +reveal_type(i) # revealed: int +``` + +```py +def random() -> bool: + raise NotImplementedError + +i = 0 +while True: + reveal_type(i) # revealed: Literal[0, 1, 2] + if random(): + i = 1 + else: + i = "break" + break + # To get here we must take the `i = 1` branch above. + reveal_type(i) # revealed: Literal[1] + if random(): + i = 2 + reveal_type(i) # revealed: Literal[1, 2] +reveal_type(i) # revealed: Literal["break"] + +i = 0 +while random(): + if random(): + reveal_type(i) # revealed: Literal[0, 1, 2, 3] + i = 1 + reveal_type(i) # revealed: Literal[1] + while random(): + if random(): + reveal_type(i) # revealed: Literal[1, 0, 2, 3] + i = 2 + reveal_type(i) # revealed: Literal[2] + if random(): + reveal_type(i) # revealed: Literal[1, 2, 0, 3] + i = 3 + reveal_type(i) # revealed: Literal[3] +```
diff --git a/crates/ty_python_semantic/src/semantic_index.rs b/crates/ty_python_semantic/src/semantic_index.rs index 8642dcf..0045590 100644 --- a/crates/ty_python_semantic/src/semantic_index.rs +++ b/crates/ty_python_semantic/src/semantic_index.rs
@@ -229,6 +229,9 @@ /// Map from a standalone expression to its [`Expression`] ingredient. expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>, + /// Map from loop-header definitions to their constituent definitions. + loop_header_definitions: FxHashMap<Definition<'db>, Vec<Definition<'db>>>, + /// Map from nodes that create a scope to the scope they create. scopes_by_node: FxHashMap<NodeWithScopeKey, FileScopeId>, @@ -319,6 +322,15 @@ self.scope_ids_by_scope.iter().copied() } + pub(crate) fn loop_header_definitions( + &self, + definition: Definition<'db>, + ) -> Option<&[Definition<'db>]> { + self.loop_header_definitions + .get(&definition) + .map(|defs| defs.as_slice()) + } + pub(crate) fn symbol_is_global_in_scope( &self, symbol: ScopedSymbolId,
diff --git a/crates/ty_python_semantic/src/semantic_index/ast_ids.rs b/crates/ty_python_semantic/src/semantic_index/ast_ids.rs index cc2c655..d3cf7b5 100644 --- a/crates/ty_python_semantic/src/semantic_index/ast_ids.rs +++ b/crates/ty_python_semantic/src/semantic_index/ast_ids.rs
@@ -141,4 +141,10 @@ Self(NodeKey::from_node(value)) } } + + impl From<ExpressionNodeKey> for NodeKey { + fn from(value: ExpressionNodeKey) -> Self { + value.0 + } + } }
diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 3ef83e9..220e581 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs
@@ -19,14 +19,15 @@ use crate::ast_node_ref::AstNodeRef; use crate::node_key::NodeKey; -use crate::semantic_index::ast_ids::AstIdsBuilder; use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; +use crate::semantic_index::ast_ids::{AstIdsBuilder, ScopedUseId}; use crate::semantic_index::definition::{ AnnotatedAssignmentDefinitionNodeRef, AssignmentDefinitionNodeRef, - ComprehensionDefinitionNodeRef, Definition, DefinitionCategory, DefinitionNodeKey, - DefinitionNodeRef, Definitions, ExceptHandlerDefinitionNodeRef, ForStmtDefinitionNodeRef, - ImportDefinitionNodeRef, ImportFromDefinitionNodeRef, ImportFromSubmoduleDefinitionNodeRef, - MatchPatternDefinitionNodeRef, StarImportDefinitionNodeRef, WithItemDefinitionNodeRef, + ComprehensionDefinitionNodeRef, Definition, DefinitionCategory, DefinitionKind, + DefinitionNodeKey, DefinitionNodeRef, Definitions, ExceptHandlerDefinitionNodeRef, + ForStmtDefinitionNodeRef, ImportDefinitionNodeRef, ImportFromDefinitionNodeRef, + ImportFromSubmoduleDefinitionNodeRef, LoopHeaderDefinitionKind, MatchPatternDefinitionNodeRef, + StarImportDefinitionNodeRef, WithItemDefinitionNodeRef, }; use crate::semantic_index::expression::{Expression, ExpressionKind}; use crate::semantic_index::place::{PlaceExpr, PlaceTableBuilder, ScopedPlaceId}; @@ -43,6 +44,7 @@ }; use crate::semantic_index::scope::{Scope, ScopeId, ScopeKind, ScopeLaziness}; use crate::semantic_index::symbol::{ScopedSymbolId, Symbol}; +use crate::semantic_index::use_def::Bindings; use crate::semantic_index::use_def::{ EnclosingSnapshotKey, FlowSnapshot, ScopedEnclosingSnapshotId, UseDefMapBuilder, }; @@ -53,22 +55,35 @@ mod except_handlers; +#[derive(Clone, Debug)] +struct LoopUse { + place: ScopedPlaceId, + use_id: ScopedUseId, +} + #[derive(Clone, Debug, Default)] struct Loop { /// Flow states at each `break` in the current loop. 
break_states: Vec<FlowSnapshot>, + uses: Vec<LoopUse>, + defined_places: FxHashSet<ScopedPlaceId>, } impl Loop { fn push_break(&mut self, state: FlowSnapshot) { self.break_states.push(state); } + + fn record_definition(&mut self, place: ScopedPlaceId) { + self.defined_places.insert(place); + } } struct ScopeInfo { file_scope_id: FileScopeId, /// Current loop state; None if we are not currently visiting a loop current_loop: Option<Loop>, + condition_place_uses: Option<FxHashSet<ScopedPlaceId>>, } pub(super) struct SemanticIndexBuilder<'db, 'ast> { @@ -109,6 +124,7 @@ scopes_by_expression: ExpressionsScopeMapBuilder, definitions_by_node: FxHashMap<DefinitionNodeKey, Definitions<'db>>, expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>, + loop_header_definitions: FxHashMap<Definition<'db>, Vec<Definition<'db>>>, imported_modules: FxHashSet<ModuleName>, seen_submodule_imports: FxHashSet<String>, /// Hashset of all [`FileScopeId`]s that correspond to [generator functions]. @@ -147,6 +163,7 @@ scopes_by_node: FxHashMap::default(), definitions_by_node: FxHashMap::default(), expressions_by_node: FxHashMap::default(), + loop_header_definitions: FxHashMap::default(), seen_submodule_imports: FxHashSet::default(), imported_modules: FxHashSet::default(), @@ -256,8 +273,17 @@ /// Pop a loop, replacing with the previous saved outer loop, if any. 
fn pop_loop(&mut self, outer_loop: Option<Loop>) -> Loop { - std::mem::replace(&mut self.current_scope_info_mut().current_loop, outer_loop) - .expect("pop_loop() should not be called without a prior push_loop()") + let inner_loop = std::mem::take(&mut self.current_scope_info_mut().current_loop) + .expect("pop_loop() should not be called without a prior push_loop()"); + let merged_outer = outer_loop.map(|mut outer| { + outer.uses.extend(inner_loop.uses.iter().cloned()); + outer + .defined_places + .extend(inner_loop.defined_places.iter().copied()); + outer + }); + self.current_scope_info_mut().current_loop = merged_outer; + inner_loop } fn current_loop_mut(&mut self) -> Option<&mut Loop> { @@ -308,6 +334,7 @@ self.scope_stack.push(ScopeInfo { file_scope_id, current_loop: None, + condition_place_uses: None, }); } @@ -656,6 +683,9 @@ place: ScopedPlaceId, definition_node: impl Into<DefinitionNodeRef<'ast, 'db>> + std::fmt::Debug + Copy, ) -> Definition<'db> { + if let Some(current_loop) = self.current_loop_mut() { + current_loop.record_definition(place); + } let (definition, num_definitions) = self.push_additional_definition(place, definition_node); debug_assert_eq!( num_definitions, 1, @@ -755,6 +785,40 @@ (definition, num_definitions) } + fn add_loop_header_definition( + &mut self, + place: ScopedPlaceId, + loop_node: &'ast ast::StmtWhile, + definitions: Vec<Definition<'db>>, + seed_definitions: Vec<Definition<'db>>, + bindings: Bindings, + seed_bindings: Bindings, + ) -> Definition<'db> { + let kind = DefinitionKind::LoopHeader(LoopHeaderDefinitionKind::new( + AstNodeRef::new(self.module, loop_node), + definitions, + seed_definitions, + bindings, + seed_bindings, + )); + let is_reexported = kind.is_reexported(); + let definition = Definition::new( + self.db, + self.file, + self.current_scope(), + place, + kind, + is_reexported, + ); + + self.add_entry_for_definition_key(DefinitionNodeKey::from_node_key(NodeKey::from_node( + loop_node, + ))) + .push(definition); + + 
definition + } + fn record_expression_narrowing_constraint( &mut self, predicate_node: &ast::Expr, @@ -1318,6 +1382,7 @@ scopes: self.scopes, definitions_by_node: self.definitions_by_node, expressions_by_node: self.expressions_by_node, + loop_header_definitions: self.loop_header_definitions, scope_ids_by_scope: self.scope_ids_by_scope, ast_ids, scopes_by_expression: self.scopes_by_expression.build(), @@ -1341,6 +1406,96 @@ self.source_text .get_or_init(|| source_text(self.db, self.file)) } + + fn record_use(&mut self, place_id: ScopedPlaceId, expr_node_key: ExpressionNodeKey) { + if let ScopedPlaceId::Symbol(symbol_id) = place_id { + self.mark_symbol_used(symbol_id); + } + let use_id = self.current_ast_ids().record_use(expr_node_key); + self.current_use_def_map_mut() + .record_use(place_id, use_id, expr_node_key.into()); + if let Some(condition_place_uses) = &mut self.current_scope_info_mut().condition_place_uses + { + condition_place_uses.insert(place_id); + } + if let Some(current_loop) = self.current_loop_mut() { + current_loop.uses.push(LoopUse { + place: place_id, + use_id, + }); + } + } + + fn create_loop_header_definitions( + &mut self, + loop_node: &'ast ast::StmtWhile, + loop_state: &Loop, + pre_loop: &FlowSnapshot, + post_body: &FlowSnapshot, + ) { + let mut used_places = FxHashSet::default(); + for loop_use in &loop_state.uses { + used_places.insert(loop_use.place); + } + + let scope_id = self.current_scope(); + for place in loop_state.defined_places.iter() { + if !used_places.contains(place) { + continue; + } + + let pre_loop_binding_ids = pre_loop.binding_ids_for_place_excluding_unbound(*place); + let seed_bindings = pre_loop.bindings_for_place(*place); + let loop_bindings = self + .current_use_def_map_mut() + .merge_bindings(seed_bindings.clone(), post_body.bindings_for_place(*place)); + let mut seed_definitions = self + .current_use_def_map() + .definitions_for_place_in_snapshot(pre_loop, *place); + let mut definitions = seed_definitions.clone(); + 
definitions.extend( + self.current_use_def_map() + .definitions_for_place_in_snapshot(post_body, *place), + ); + definitions.sort(); + definitions.dedup(); + seed_definitions.sort(); + seed_definitions.dedup(); + + if definitions.is_empty() { + continue; + } + + let header_definition = self.add_loop_header_definition( + *place, + loop_node, + definitions.clone(), + seed_definitions, + loop_bindings, + seed_bindings, + ); + let header_definition_id = self.use_def_maps[scope_id] + .register_definition_with_bindings( + header_definition, + pre_loop.bindings_for_place(*place), + pre_loop.declarations_for_place(*place), + ); + self.loop_header_definitions + .insert(header_definition, definitions); + + if pre_loop_binding_ids.is_empty() { + continue; + } + + for loop_use in loop_state.uses.iter().filter(|use_| use_.place == *place) { + self.current_use_def_map_mut().replace_use_bindings( + loop_use.use_id, + &pre_loop_binding_ids, + header_definition_id, + ); + } + } + } } impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { @@ -1924,41 +2079,66 @@ self.in_type_checking_block = is_outer_block_in_type_checking; } - ast::Stmt::While(ast::StmtWhile { - test, - body, - orelse, - range: _, - node_index: _, - }) => { + ast::Stmt::While(stmt_while) => { + let ast::StmtWhile { + test, + body, + orelse, + range: _, + node_index: _, + } = stmt_while; + self.current_scope_info_mut() + .condition_place_uses + .replace(FxHashSet::default()); + let outer_loop = self.push_loop(); + let has_outer_loop = outer_loop.is_some(); self.visit_expr(test); + let condition_place_uses = self + .current_scope_info_mut() + .condition_place_uses + .take() + .unwrap(); let pre_loop = self.flow_snapshot(); - let predicate = self.record_expression_narrowing_constraint(test); - self.record_reachability_constraint(predicate); + let predicate = self.build_predicate(test); + let predicate_id = self.add_predicate(predicate); + self.record_narrowing_constraint_id(predicate_id); + 
self.record_reachability_constraint_id(predicate_id); - let outer_loop = self.push_loop(); self.visit_body(body); let this_loop = self.pop_loop(outer_loop); + let post_body = self.flow_snapshot(); + if !has_outer_loop { + self.create_loop_header_definitions( + stmt_while, &this_loop, &pre_loop, &post_body, + ); + } // We execute the `else` branch once the condition evaluates to false. This could // happen without ever executing the body, if the condition is false the first time // it's tested. Or it could happen if a _later_ evaluation of the condition yields // false. So we merge in the pre-loop state here into the post-body state: - self.flow_merge(pre_loop); + self.flow_merge(pre_loop.clone()); // The `else` branch can only be reached if the loop condition *can* be false. To // model this correctly, we need a second copy of the while condition constraint, // since the first and later evaluations might produce different results. We would // otherwise simplify `predicate AND ~predicate` to `False`. 
- let later_predicate_id = self.current_use_def_map_mut().add_predicate(predicate); - let later_reachability_constraint = self - .current_reachability_constraints_mut() - .add_atom(later_predicate_id); - self.record_negated_reachability_constraint(later_reachability_constraint); - - self.record_negated_narrowing_constraint(predicate); + let condition_depends_on_loop = condition_place_uses + .iter() + .any(|place| pre_loop.has_new_bindings_for_place(&post_body, *place)); + if condition_depends_on_loop { + self.record_ambiguous_reachability(); + } else { + let later_predicate_id = + self.current_use_def_map_mut().add_predicate(predicate); + let later_reachability_constraint = self + .current_reachability_constraints_mut() + .add_atom(later_predicate_id); + self.record_negated_reachability_constraint(later_reachability_constraint); + self.record_negated_narrowing_constraint(predicate); + } self.visit_body(orelse); @@ -2469,12 +2649,7 @@ let place_id = self.add_place(place_expr); if is_use { - if let ScopedPlaceId::Symbol(symbol_id) = place_id { - self.mark_symbol_used(symbol_id); - } - let use_id = self.current_ast_ids().record_use(expr); - self.current_use_def_map_mut() - .record_use(place_id, use_id, node_key); + self.record_use(place_id, expr.into()); } if is_definition {
diff --git a/crates/ty_python_semantic/src/semantic_index/definition.rs b/crates/ty_python_semantic/src/semantic_index/definition.rs index d156c3a..eebb463 100644 --- a/crates/ty_python_semantic/src/semantic_index/definition.rs +++ b/crates/ty_python_semantic/src/semantic_index/definition.rs
@@ -13,6 +13,7 @@ use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::scope::{FileScopeId, ScopeId}; use crate::semantic_index::symbol::ScopedSymbolId; +use crate::semantic_index::use_def::Bindings; use crate::unpack::{Unpack, UnpackPosition}; /// A definition of a place. @@ -753,6 +754,7 @@ TypeVar(AstNodeRef<ast::TypeParamTypeVar>), ParamSpec(AstNodeRef<ast::TypeParamParamSpec>), TypeVarTuple(AstNodeRef<ast::TypeParamTypeVarTuple>), + LoopHeader(LoopHeaderDefinitionKind<'db>), } impl DefinitionKind<'_> { @@ -835,6 +837,7 @@ DefinitionKind::TypeVarTuple(type_var_tuple) => { type_var_tuple.node(module).name.range() } + DefinitionKind::LoopHeader(loop_header) => loop_header.node(module).range(), } } @@ -880,6 +883,7 @@ DefinitionKind::TypeVar(type_var) => type_var.node(module).range(), DefinitionKind::ParamSpec(param_spec) => param_spec.node(module).range(), DefinitionKind::TypeVarTuple(type_var_tuple) => type_var_tuple.node(module).range(), + DefinitionKind::LoopHeader(loop_header) => loop_header.node(module).range(), } } @@ -935,7 +939,8 @@ | DefinitionKind::WithItem(_) | DefinitionKind::MatchPattern(_) | DefinitionKind::ImportFromSubmodule(_) - | DefinitionKind::ExceptHandler(_) => DefinitionCategory::Binding, + | DefinitionKind::ExceptHandler(_) + | DefinitionKind::LoopHeader(_) => DefinitionCategory::Binding, } } } @@ -1211,9 +1216,62 @@ } } +#[derive(Clone, Debug, get_size2::GetSize)] +pub struct LoopHeaderDefinitionKind<'db> { + node: AstNodeRef<ast::StmtWhile>, + definitions: Vec<Definition<'db>>, + seed_definitions: Vec<Definition<'db>>, + bindings: Bindings, + seed_bindings: Bindings, +} + +impl<'db> LoopHeaderDefinitionKind<'db> { + pub(crate) fn new( + node: AstNodeRef<ast::StmtWhile>, + definitions: Vec<Definition<'db>>, + seed_definitions: Vec<Definition<'db>>, + bindings: Bindings, + seed_bindings: Bindings, + ) -> Self { + Self { + node, + definitions, + seed_definitions, + bindings, + seed_bindings, + } + } + + pub(crate) fn 
node<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::StmtWhile { + self.node.node(module) + } + + pub(crate) fn definitions(&self) -> &[Definition<'db>] { + &self.definitions + } + + pub(crate) fn seed_definitions(&self) -> &[Definition<'db>] { + &self.seed_definitions + } + + pub(crate) fn bindings(&self) -> &Bindings { + &self.bindings + } + + pub(crate) fn seed_bindings(&self) -> &Bindings { + &self.seed_bindings + } +} + #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, salsa::Update, get_size2::GetSize)] pub(crate) struct DefinitionNodeKey(NodeKey); +impl DefinitionNodeKey { + pub(crate) fn from_node_key(node_key: NodeKey) -> Self { + Self(node_key) + } +} + impl From<&ast::Alias> for DefinitionNodeKey { fn from(node: &ast::Alias) -> Self { Self(NodeKey::from_node(node)) @@ -1280,6 +1338,12 @@ } } +impl From<&ast::StmtWhile> for DefinitionNodeKey { + fn from(node: &ast::StmtWhile) -> Self { + Self(NodeKey::from_node(node)) + } +} + impl From<&ast::Parameter> for DefinitionNodeKey { fn from(node: &ast::Parameter) -> Self { Self(NodeKey::from_node(node))
diff --git a/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs b/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs index 8d27dd1..2f7ca8f 100644 --- a/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/narrowing_constraints.rs
@@ -111,6 +111,13 @@ ) -> ScopedNarrowingConstraint { self.lists.intersect(a, b) } + + pub(crate) fn iter_predicates( + &self, + set: ScopedNarrowingConstraint, + ) -> NarrowingConstraintsIterator<'_> { + self.lists.iter_set_reverse(set).copied() + } } // Iteration @@ -142,12 +149,5 @@ } } - impl NarrowingConstraintsBuilder { - pub(crate) fn iter_predicates( - &self, - set: ScopedNarrowingConstraint, - ) -> NarrowingConstraintsIterator<'_> { - self.lists.iter_set_reverse(set).copied() - } - } + // Test-only impl removed; use the main impl above. }
diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index c349281..931884c 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs
@@ -261,8 +261,9 @@ }; use crate::semantic_index::scope::{FileScopeId, ScopeKind, ScopeLaziness}; use crate::semantic_index::symbol::ScopedSymbolId; +pub(crate) use crate::semantic_index::use_def::place_state::Bindings; use crate::semantic_index::use_def::place_state::{ - Bindings, Declarations, EnclosingSnapshot, LiveBindingsIterator, LiveDeclaration, + Declarations, EnclosingSnapshot, LiveBindingsIterator, LiveDeclaration, LiveDeclarationsIterator, PlaceState, PreviousDefinitions, ScopedDefinitionId, }; use crate::semantic_index::{EnclosingSnapshotResult, SemanticIndex}; @@ -364,6 +365,13 @@ ) } + pub(crate) fn bindings_from_snapshot<'a>( + &'a self, + bindings: &'a Bindings, + ) -> BindingWithConstraintsIterator<'a, 'db> { + self.bindings_iterator(bindings, BoundnessAnalysis::BasedOnUnboundVisibility) + } + pub(crate) fn applicable_constraints( &self, constraint_key: ConstraintKey, @@ -825,6 +833,52 @@ reachability: ScopedReachabilityConstraintId, } +impl FlowSnapshot { + pub(super) fn has_new_bindings_for_place( + &self, + other: &FlowSnapshot, + place: ScopedPlaceId, + ) -> bool { + let (self_ids, other_ids) = match place { + ScopedPlaceId::Symbol(symbol) => ( + self.symbol_states[symbol].bindings().binding_ids(), + other.symbol_states[symbol].bindings().binding_ids(), + ), + ScopedPlaceId::Member(member) => ( + self.member_states[member].bindings().binding_ids(), + other.member_states[member].bindings().binding_ids(), + ), + }; + other_ids.iter().any(|id| !self_ids.contains(id)) + } + + pub(super) fn bindings_for_place(&self, place: ScopedPlaceId) -> Bindings { + match place { + ScopedPlaceId::Symbol(symbol) => self.symbol_states[symbol].bindings().clone(), + ScopedPlaceId::Member(member) => self.member_states[member].bindings().clone(), + } + } + + pub(super) fn binding_ids_for_place_excluding_unbound( + &self, + place: ScopedPlaceId, + ) -> Vec<ScopedDefinitionId> { + let mut ids = match place { + ScopedPlaceId::Symbol(symbol) => 
self.symbol_states[symbol].bindings().binding_ids(), + ScopedPlaceId::Member(member) => self.member_states[member].bindings().binding_ids(), + }; + ids.retain(|id| *id != ScopedDefinitionId::UNBOUND); + ids + } + + pub(super) fn declarations_for_place(&self, place: ScopedPlaceId) -> Declarations { + match place { + ScopedPlaceId::Symbol(symbol) => self.symbol_states[symbol].declarations().clone(), + ScopedPlaceId::Member(member) => self.member_states[member].declarations().clone(), + } + } +} + /// A snapshot of the state of a single symbol (e.g. `obj`) and all of its associated members /// (e.g. `obj.attr`, `obj["key"]`). pub(super) struct SingleSymbolSnapshot { @@ -1248,6 +1302,107 @@ self.record_node_reachability(node_key); } + pub(super) fn merge_use_with_snapshot( + &mut self, + use_id: ScopedUseId, + snapshot: &FlowSnapshot, + place: ScopedPlaceId, + predicate: Option<ScopedPredicateId>, + ) { + let mut bindings = self.bindings_by_use[use_id].clone(); + let mut backedge_bindings = snapshot.bindings_for_place(place); + if let Some(predicate) = predicate { + if predicate != ScopedPredicateId::ALWAYS_TRUE + && predicate != ScopedPredicateId::ALWAYS_FALSE + { + backedge_bindings + .record_narrowing_constraint(&mut self.narrowing_constraints, predicate.into()); + } + } + bindings.merge( + backedge_bindings, + &mut self.narrowing_constraints, + &mut self.reachability_constraints, + ); + self.bindings_by_use[use_id] = bindings; + } + + pub(super) fn register_definition( + &mut self, + definition: Definition<'db>, + ) -> ScopedDefinitionId { + self.all_definitions + .push(DefinitionState::Defined(definition)) + } + + pub(super) fn register_definition_with_bindings( + &mut self, + definition: Definition<'db>, + bindings: Bindings, + declarations: Declarations, + ) -> ScopedDefinitionId { + let def_id = self + .all_definitions + .push(DefinitionState::Defined(definition)); + self.bindings_by_definition.insert(definition, bindings); + self.declarations_by_binding + 
.insert(definition, declarations); + def_id + } + + pub(super) fn replace_use_with_definition( + &mut self, + use_id: ScopedUseId, + definition_id: ScopedDefinitionId, + ) { + self.bindings_by_use[use_id].replace_definition(definition_id); + } + + pub(super) fn replace_use_bindings( + &mut self, + use_id: ScopedUseId, + from: &[ScopedDefinitionId], + definition_id: ScopedDefinitionId, + ) { + self.bindings_by_use[use_id].replace_definitions( + from, + definition_id, + &mut self.narrowing_constraints, + &mut self.reachability_constraints, + ); + } + + pub(super) fn merge_bindings(&mut self, mut bindings: Bindings, other: Bindings) -> Bindings { + bindings.merge( + other, + &mut self.narrowing_constraints, + &mut self.reachability_constraints, + ); + bindings + } + + pub(super) fn definitions_for_place_in_snapshot( + &self, + snapshot: &FlowSnapshot, + place: ScopedPlaceId, + ) -> Vec<Definition<'db>> { + let binding_ids = match place { + ScopedPlaceId::Symbol(symbol) => { + snapshot.symbol_states[symbol].bindings().binding_ids() + } + ScopedPlaceId::Member(member) => { + snapshot.member_states[member].bindings().binding_ids() + } + }; + binding_ids + .into_iter() + .filter_map(|binding_id| match self.all_definitions[binding_id] { + DefinitionState::Defined(definition) => Some(definition), + DefinitionState::Undefined | DefinitionState::Deleted => None, + }) + .collect() + } + pub(super) fn record_node_reachability(&mut self, node_key: NodeKey) { self.node_reachability.insert(node_key, self.reachability); }
diff --git a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs index 4695bda..3b992ad 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs
@@ -56,7 +56,7 @@ /// A newtype-index for a definition in a particular scope. #[newtype_index] #[derive(Ord, PartialOrd, get_size2::GetSize)] -pub(super) struct ScopedDefinitionId; +pub(crate) struct ScopedDefinitionId; impl ScopedDefinitionId { /// A special ID that is used to describe an implicit start-of-scope state. When @@ -74,7 +74,7 @@ /// Live declarations for a single place at some point in control flow, with their /// corresponding reachability constraints. #[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update, get_size2::GetSize)] -pub(super) struct Declarations { +pub(crate) struct Declarations { /// A list of live declarations for this place, sorted by their `ScopedDefinitionId` live_declarations: SmallVec<[LiveDeclaration; 2]>, } @@ -206,7 +206,7 @@ /// Live bindings for a single place at some point in control flow. Each live binding comes /// with a set of narrowing constraints and a reachability constraint. #[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update, get_size2::GetSize)] -pub(super) struct Bindings { +pub(crate) struct Bindings { /// The narrowing constraint applicable to the "unbound" binding, if we need access to it even /// when it's not visible. 
This happens in class scopes, where local name bindings are not visible /// to nested scopes, but we still need to know what narrowing constraints were applied to the @@ -222,12 +222,109 @@ .unwrap_or(self.live_bindings[0].narrowing_constraint) } + pub(super) fn has_defined_bindings(&self) -> bool { + self.live_bindings + .iter() + .any(|binding| !binding.binding.is_unbound()) + } + + pub(super) fn from_single( + binding: ScopedDefinitionId, + narrowing_constraint: ScopedNarrowingConstraint, + reachability_constraint: ScopedReachabilityConstraintId, + ) -> Self { + Self { + unbound_narrowing_constraint: None, + live_bindings: smallvec![LiveBinding { + binding, + narrowing_constraint, + reachability_constraint, + }], + } + } + + pub(super) fn representative_narrowing_constraint(&self) -> ScopedNarrowingConstraint { + self.live_bindings + .first() + .map(|binding| binding.narrowing_constraint) + .unwrap_or_else(ScopedNarrowingConstraint::empty) + } + + pub(super) fn retain_bindings( + &mut self, + mut predicate: impl FnMut(ScopedDefinitionId) -> bool, + ) { + self.live_bindings + .retain(|binding| predicate(binding.binding)); + } + + pub(super) fn binding_ids(&self) -> Vec<ScopedDefinitionId> { + self.live_bindings + .iter() + .map(|binding| binding.binding) + .collect() + } + pub(super) fn finish(&mut self, reachability_constraints: &mut ReachabilityConstraintsBuilder) { self.live_bindings.shrink_to_fit(); for binding in &self.live_bindings { reachability_constraints.mark_used(binding.reachability_constraint); } } + + pub(super) fn replace_definition(&mut self, binding: ScopedDefinitionId) { + for live_binding in &mut self.live_bindings { + live_binding.binding = binding; + } + self.live_bindings + .sort_by(|left, right| left.binding.cmp(&right.binding)); + } + + pub(super) fn replace_definitions( + &mut self, + from: &[ScopedDefinitionId], + replacement: ScopedDefinitionId, + narrowing_constraints: &mut NarrowingConstraintsBuilder, + reachability_constraints: &mut 
ReachabilityConstraintsBuilder, + ) { + if from.is_empty() { + return; + } + + let mut changed = false; + for live_binding in &mut self.live_bindings { + if from.contains(&live_binding.binding) { + live_binding.binding = replacement; + changed = true; + } + } + + if !changed { + return; + } + + self.live_bindings + .sort_by(|left, right| left.binding.cmp(&right.binding)); + + let mut merged: SmallVec<[LiveBinding; 2]> = SmallVec::new(); + for binding in std::mem::take(&mut self.live_bindings) { + match merged.last_mut() { + Some(last) if last.binding == binding.binding => { + last.narrowing_constraint = narrowing_constraints.intersect_constraints( + last.narrowing_constraint, + binding.narrowing_constraint, + ); + last.reachability_constraint = reachability_constraints.add_or_constraint( + last.reachability_constraint, + binding.reachability_constraint, + ); + } + _ => merged.push(binding), + } + } + + self.live_bindings = merged; + } } /// One of the live bindings for a single place at some point in control flow. @@ -291,6 +388,26 @@ } } + pub(super) fn add_narrowing_constraint_from( + &mut self, + narrowing_constraints: &mut NarrowingConstraintsBuilder, + constraint: ScopedNarrowingConstraint, + ) { + let predicates: Vec<_> = narrowing_constraints.iter_predicates(constraint).collect(); + let mut unbound_constraint = self.unbound_narrowing_constraint; + for predicate in predicates { + if let Some(existing) = unbound_constraint { + unbound_constraint = + Some(narrowing_constraints.add_predicate_to_constraint(existing, predicate)); + } + for binding in &mut self.live_bindings { + binding.narrowing_constraint = narrowing_constraints + .add_predicate_to_constraint(binding.narrowing_constraint, predicate); + } + } + self.unbound_narrowing_constraint = unbound_constraint; + } + /// Add given reachability constraint to all live bindings. pub(super) fn record_reachability_constraint( &mut self,
diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 76cc468..6d2b335 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs
@@ -18,6 +18,8 @@
 use ruff_python_ast as ast;
 use ruff_python_ast::name::Name;
 use ruff_text_size::{Ranged, TextRange};
+use rustc_hash::FxHashMap;
+use salsa::plumbing::{AsId, Id};
 use smallvec::{SmallVec, smallvec};
 use ty_module_resolver::{KnownModule, Module, ModuleName, resolve_module};
 
@@ -158,8 +160,87 @@
     diagnostics
 }
 
+thread_local! {
+    /// Loop-header overrides active on the current thread, keyed by the `Id` of the
+    /// overridden definition. Read by `loop_header_override`; entries are installed by
+    /// `with_loop_header_override` and rolled back by `LoopHeaderOverrideGuard`.
+    static LOOP_HEADER_OVERRIDE: RefCell<FxHashMap<Id, Type<'static>>> =
+        RefCell::new(FxHashMap::default());
+}
+
+/// Returns the override currently installed for `definition` on this thread, if any.
+fn loop_header_override<'db>(definition: Definition<'db>) -> Option<Type<'db>> {
+    let id = definition.as_id();
+    LOOP_HEADER_OVERRIDE.with(|cell| cell.borrow().get(&id).copied().map(restore_type_lifetime))
+}
+
+/// Rolls back one `LOOP_HEADER_OVERRIDE` entry when dropped.
+///
+/// Because the rollback lives in `Drop`, it also runs while unwinding out of the
+/// closure passed to `with_loop_header_override`, so a lifetime-erased `Type` is never
+/// left behind in the thread-local.
+struct LoopHeaderOverrideGuard {
+    id: Id,
+    /// The entry that was displaced when the override was installed, if any.
+    previous: Option<Type<'static>>,
+}
+
+impl Drop for LoopHeaderOverrideGuard {
+    fn drop(&mut self) {
+        LOOP_HEADER_OVERRIDE.with(|cell| {
+            let mut overrides = cell.borrow_mut();
+            match self.previous {
+                Some(previous) => {
+                    overrides.insert(self.id, previous);
+                }
+                None => {
+                    overrides.remove(&self.id);
+                }
+            }
+        });
+    }
+}
+
+/// Runs `f` while `binding_type(definition)` is forced to `ty` on the current thread.
+///
+/// Whatever override was installed for `definition` before the call (possibly none) is
+/// reinstated when `f` returns or panics.
+pub(crate) fn with_loop_header_override<'db, R>(
+    definition: Definition<'db>,
+    ty: Type<'db>,
+    f: impl FnOnce() -> R,
+) -> R {
+    let id = definition.as_id();
+    // The `RefCell` borrow must end before `f` runs: `f` may read the map again through
+    // `binding_type`.
+    let previous =
+        LOOP_HEADER_OVERRIDE.with(|cell| cell.borrow_mut().insert(id, erase_type_lifetime(ty)));
+    let _guard = LoopHeaderOverrideGuard { id, previous };
+    f()
+}
+
+/// Erases the database lifetime so the type can be stored in the thread-local.
+fn erase_type_lifetime<'db>(ty: Type<'db>) -> Type<'static> {
+    // SAFETY: `Type` is a copyable, db-backed handle; we only use this within
+    // a single thread to provide a temporary loop-header override.
+    unsafe { std::mem::transmute::<Type<'db>, Type<'static>>(ty) }
+}
+
+/// Reattaches a database lifetime to a type taken from the thread-local.
+///
+/// NOTE(review): the caller chooses `'db`; nothing here checks that it matches the
+/// lifetime that was erased.
+fn restore_type_lifetime<'db>(ty: Type<'static>) -> Type<'db> {
+    // SAFETY: This is the inverse of `erase_type_lifetime` and is only used
+    // within the dynamic scope of a loop-header override.
+    unsafe { std::mem::transmute::<Type<'static>, Type<'db>>(ty) }
+}
+
 /// Infer the type of a binding.
pub(crate) fn binding_type<'db>(db: &'db dyn Db, definition: Definition<'db>) -> Type<'db> {
-    let inference = infer_definition_types(db, definition);
-    inference.binding_type(definition)
+    // A loop-header override installed on this thread takes precedence over inference.
+    match loop_header_override(definition) {
+        Some(override_type) => override_type,
+        None => infer_definition_types(db, definition).binding_type(definition),
+    }
 }
@@ -11792,7 +11845,9 @@
             elements
                 .into_iter()
                 .fold(
-                    UnionBuilder::new(db).cycle_recovery(true),
+                    UnionBuilder::new(db)
+                        .cycle_recovery(true)
+                        .recursively_defined(RecursivelyDefined::Yes),
                     |builder, element| builder.add(element.into()),
                 )
                 .build()
diff --git a/crates/ty_python_semantic/src/types/ide_support.rs b/crates/ty_python_semantic/src/types/ide_support.rs index 70fa611..7aacb05 100644 --- a/crates/ty_python_semantic/src/types/ide_support.rs +++ b/crates/ty_python_semantic/src/types/ide_support.rs
@@ -1377,7 +1377,8 @@ | DefinitionKind::ExceptHandler(_) | DefinitionKind::TypeVar(_) | DefinitionKind::ParamSpec(_) - | DefinitionKind::TypeVarTuple(_) => { + | DefinitionKind::TypeVarTuple(_) + | DefinitionKind::LoopHeader(_) => { // Not yet implemented return Err(()); }
diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 4332542..4aa4ddc 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs
@@ -39,7 +39,7 @@
 use crate::semantic_index::definition::{
     AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, ComprehensionDefinitionKind,
     Definition, DefinitionKind, DefinitionNodeKey, DefinitionState, ExceptHandlerDefinitionKind,
-    ForStmtDefinitionKind, TargetKind, WithItemDefinitionKind,
+    ForStmtDefinitionKind, LoopHeaderDefinitionKind, TargetKind, WithItemDefinitionKind,
 };
 use crate::semantic_index::expression::{Expression, ExpressionKind};
 use crate::semantic_index::narrowing_constraints::ConstraintKey;
@@ -121,7 +121,7 @@
     TypeQualifiers, TypeVarBoundOrConstraints, TypeVarBoundOrConstraintsEvaluation,
     TypeVarDefaultEvaluation, TypeVarIdentity, TypeVarInstance, TypeVarKind, TypeVarVariance,
     TypedDictType, UnionBuilder, UnionType, UnionTypeInstance, binding_type, infer_scope_types,
-    todo_type,
+    todo_type, with_loop_header_override,
 };
 use crate::types::{CallableTypes, overrides};
 use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic};
@@ -1501,6 +1501,9 @@
             DefinitionKind::TypeVarTuple(node) => {
                 self.infer_typevartuple_definition(node.node(self.module()), definition);
             }
+            DefinitionKind::LoopHeader(loop_header) => {
+                self.infer_loop_header_definition(&loop_header, definition);
+            }
         }
     }
 
@@ -1530,6 +1533,7 @@
             DefinitionKind::Assignment(assignment) => {
                 self.infer_assignment_deferred(assignment.value(self.module()));
             }
+            DefinitionKind::LoopHeader(_) => {}
             _ => {}
         }
     }
@@ -5384,6 +5388,72 @@
         add.insert(self, target_ty);
     }
 
+    /// Infers the type bound by a loop-header definition by iterating to a fixed point.
+    ///
+    /// The first candidate is the union of the binding types of `seed_definitions()`
+    /// (`Unknown` if that union is empty). Each round re-infers `definitions()` while this
+    /// definition is overridden to the current candidate and unions the results; iteration
+    /// stops once a round reproduces its input. If the candidate is still changing after
+    /// `MAX_LOOP_HEADER_ITERATIONS` rounds, it is replaced by `int`.
+    fn infer_loop_header_definition(
+        &mut self,
+        loop_header: &LoopHeaderDefinitionKind<'db>,
+        definition: Definition<'db>,
+    ) {
+        const MAX_LOOP_HEADER_ITERATIONS: usize = 8;
+
+        let add = self.add_binding(loop_header.node(self.module()).into(), definition);
+
+        let mut seeds = UnionBuilder::new(self.db());
+        for seed in loop_header.seed_definitions() {
+            seeds = seeds.add(binding_type(self.db(), *seed));
+        }
+        let mut current = if seeds.is_empty() {
+            Type::unknown()
+        } else {
+            seeds.build()
+        };
+
+        // Whether the most recent round produced a type different from the one it started with.
+        let mut still_changing = false;
+        for _ in 0..MAX_LOOP_HEADER_ITERATIONS {
+            let saved_state = self.multi_inference_state;
+            self.multi_inference_state = MultiInferenceState::Overwrite;
+            let next = with_loop_header_override(definition, current, || {
+                let mut round = UnionBuilder::new(self.db());
+                for inner in loop_header.definitions() {
+                    let inferred = match inner.kind(self.db()) {
+                        DefinitionKind::Assignment(assignment) => self
+                            .infer_assignment_definition_impl(
+                                &assignment,
+                                *inner,
+                                TypeContext::default(),
+                            ),
+                        DefinitionKind::AugmentedAssignment(augmented_assignment) => {
+                            self.infer_augment_assignment(augmented_assignment.node(self.module()))
+                        }
+                        _ => binding_type(self.db(), *inner),
+                    };
+                    round = round.add(inferred);
+                }
+                round.build()
+            });
+            self.multi_inference_state = saved_state;
+
+            still_changing = next != current;
+            if !still_changing {
+                break;
+            }
+            current = next;
+        }
+
+        if still_changing {
+            current = KnownClass::Int.to_instance(self.db());
+        }
+
+        add.insert(self, current);
+    }
+
     fn infer_assignment_definition_impl(
         &mut self,
         assignment: &AssignmentDefinitionKind<'db>,