Rollup merge of #64500 - nnethercote:ObligForest-fixups, r=nikomatsakis
Various `ObligationForest` improvements
These commits make the code both nicer and faster.
r? @nikomatsakis
diff --git a/src/librustc_data_structures/indexed_vec.rs b/src/librustc_data_structures/indexed_vec.rs
index 6f40d05..6e80b48 100644
--- a/src/librustc_data_structures/indexed_vec.rs
+++ b/src/librustc_data_structures/indexed_vec.rs
@@ -149,7 +149,7 @@
#[inline]
$v const unsafe fn from_u32_unchecked(value: u32) -> Self {
- unsafe { $type { private: value } }
+ $type { private: value }
}
/// Extracts the value of this index as an integer.
diff --git a/src/librustc_data_structures/obligation_forest/graphviz.rs b/src/librustc_data_structures/obligation_forest/graphviz.rs
index a0363e1..b2120b1 100644
--- a/src/librustc_data_structures/obligation_forest/graphviz.rs
+++ b/src/librustc_data_structures/obligation_forest/graphviz.rs
@@ -74,9 +74,9 @@
.flat_map(|i| {
let node = &self.nodes[i];
- node.parent.iter().map(|p| p.get())
- .chain(node.dependents.iter().map(|p| p.get()))
- .map(move |p| (p, i))
+ node.parent.iter()
+ .chain(node.dependents.iter())
+ .map(move |p| (p.index(), i))
})
.collect()
}
diff --git a/src/librustc_data_structures/obligation_forest/mod.rs b/src/librustc_data_structures/obligation_forest/mod.rs
index 6c52e62..189506b 100644
--- a/src/librustc_data_structures/obligation_forest/mod.rs
+++ b/src/librustc_data_structures/obligation_forest/mod.rs
@@ -9,7 +9,7 @@
//! `ObligationForest` supports two main public operations (there are a
//! few others not discussed here):
//!
-//! 1. Add a new root obligations (`push_tree`).
+//! 1. Add a new root obligation (`register_obligation`).
//! 2. Process the pending obligations (`process_obligations`).
//!
//! When a new obligation `N` is added, it becomes the root of an
@@ -20,13 +20,13 @@
//! with every pending obligation (so that will include `N`, the first
//! time). The callback also receives a (mutable) reference to the
//! per-tree state `T`. The callback should process the obligation `O`
-//! that it is given and return one of three results:
+//! that it is given and return a `ProcessResult`:
//!
-//! - `Ok(None)` -> ambiguous result. Obligation was neither a success
+//! - `Unchanged` -> ambiguous result. Obligation was neither a success
//! nor a failure. It is assumed that further attempts to process the
//! obligation will yield the same result unless something in the
//! surrounding environment changes.
-//! - `Ok(Some(C))` - the obligation was *shallowly successful*. The
+//! - `Changed(C)` - the obligation was *shallowly successful*. The
//! vector `C` is a list of subobligations. The meaning of this is that
//! `O` was successful on the assumption that all the obligations in `C`
//! are also successful. Therefore, `O` is only considered a "true"
@@ -34,7 +34,7 @@
//! state and the obligations in `C` become the new pending
//! obligations. They will be processed the next time you call
//! `process_obligations`.
-//! - `Err(E)` -> obligation failed with error `E`. We will collect this
+//! - `Error(E)` -> obligation failed with error `E`. We will collect this
//! error and return it from `process_obligations`, along with the
//! "backtrace" of obligations (that is, the list of obligations up to
//! and including the root of the failed obligation). No further
@@ -47,55 +47,50 @@
//! - `completed`: a list of obligations where processing was fully
//! completed without error (meaning that all transitive subobligations
//! have also been completed). So, for example, if the callback from
-//! `process_obligations` returns `Ok(Some(C))` for some obligation `O`,
+//! `process_obligations` returns `Changed(C)` for some obligation `O`,
//! then `O` will be considered completed right away if `C` is the
//! empty vector. Otherwise it will only be considered completed once
//! all the obligations in `C` have been found completed.
//! - `errors`: a list of errors that occurred and associated backtraces
//! at the time of error, which can be used to give context to the user.
//! - `stalled`: if true, then none of the existing obligations were
-//! *shallowly successful* (that is, no callback returned `Ok(Some(_))`).
+//! *shallowly successful* (that is, no callback returned `Changed(_)`).
//! This implies that all obligations were either errors or returned an
//! ambiguous result, which means that any further calls to
//! `process_obligations` would simply yield back further ambiguous
//! results. This is used by the `FulfillmentContext` to decide when it
//! has reached a steady state.
//!
-//! #### Snapshots
-//!
-//! The `ObligationForest` supports a limited form of snapshots; see
-//! `start_snapshot`, `commit_snapshot`, and `rollback_snapshot`. In
-//! particular, you can use a snapshot to roll back new root
-//! obligations. However, it is an error to attempt to
-//! `process_obligations` during a snapshot.
-//!
//! ### Implementation details
//!
//! For the most part, comments specific to the implementation are in the
//! code. This file only contains a very high-level overview. Basically,
//! the forest is stored in a vector. Each element of the vector is a node
-//! in some tree. Each node in the vector has the index of an (optional)
-//! parent and (for convenience) its root (which may be itself). It also
-//! has a current state, described by `NodeState`. After each
-//! processing step, we compress the vector to remove completed and error
-//! nodes, which aren't needed anymore.
+//! in some tree. Each node in the vector has the indices of its dependents,
+//! including the first dependent which is known as the parent. It also
+//! has a current state, described by `NodeState`. After each processing
+//! step, we compress the vector to remove completed and error nodes, which
+//! aren't needed anymore.
use crate::fx::{FxHashMap, FxHashSet};
+use crate::indexed_vec::Idx;
+use crate::newtype_index;
-use std::cell::Cell;
+use std::cell::{Cell, RefCell};
use std::collections::hash_map::Entry;
use std::fmt::Debug;
use std::hash;
use std::marker::PhantomData;
-mod node_index;
-use self::node_index::NodeIndex;
-
mod graphviz;
#[cfg(test)]
mod tests;
+newtype_index! {
+ pub struct NodeIndex { .. }
+}
+
pub trait ForestObligation : Clone + Debug {
type Predicate : Clone + hash::Hash + Eq + Debug;
@@ -148,18 +143,22 @@
/// At the end of processing, those nodes will be removed by a
/// call to `compress`.
///
- /// At all times we maintain the invariant that every node appears
- /// at a higher index than its parent. This is needed by the
- /// backtrace iterator (which uses `split_at`).
+ /// Ideally, this would be an `IndexVec<NodeIndex, Node<O>>`. But that is
+ /// slower, because this vector is accessed so often that the
+ /// `u32`-to-`usize` conversions required for accesses are significant.
nodes: Vec<Node<O>>,
/// A cache of predicates that have been successfully completed.
done_cache: FxHashSet<O::Predicate>,
- /// An cache of the nodes in `nodes`, indexed by predicate.
+ /// A cache of the nodes in `nodes`, indexed by predicate. Unfortunately,
+ /// its contents are not guaranteed to match those of `nodes`. See the
+ /// comments in `process_obligations` for details.
waiting_cache: FxHashMap<O::Predicate, NodeIndex>,
- scratch: Option<Vec<usize>>,
+ /// A scratch vector reused in various operations, to avoid allocating new
+ /// vectors.
+ scratch: RefCell<Vec<usize>>,
obligation_tree_id_generator: ObligationTreeIdGenerator,
@@ -178,19 +177,41 @@
obligation: O,
state: Cell<NodeState>,
- /// The parent of a node - the original obligation of
- /// which it is a subobligation. Except for error reporting,
- /// it is just like any member of `dependents`.
+ /// The parent of a node - the original obligation of which it is a
+ /// subobligation. Except for error reporting, it is just like any member
+ /// of `dependents`.
+ ///
+ /// Unlike `ObligationForest::nodes`, this uses `NodeIndex` rather than
+ /// `usize` for the index, because keeping the size down is more important
+ /// than the cost of converting to a `usize` for indexing.
parent: Option<NodeIndex>,
- /// Obligations that depend on this obligation for their
- /// completion. They must all be in a non-pending state.
+ /// Obligations that depend on this obligation for their completion. They
+ /// must all be in a non-pending state.
+ ///
+ /// This uses `NodeIndex` for the same reason as `parent`.
dependents: Vec<NodeIndex>,
/// Identifier of the obligation tree to which this node belongs.
obligation_tree_id: ObligationTreeId,
}
+impl<O> Node<O> {
+ fn new(
+ parent: Option<NodeIndex>,
+ obligation: O,
+ obligation_tree_id: ObligationTreeId
+ ) -> Node<O> {
+ Node {
+ obligation,
+ state: Cell::new(NodeState::Pending),
+ parent,
+ dependents: vec![],
+ obligation_tree_id,
+ }
+ }
+}
+
/// The state of one node in some tree within the forest. This
/// represents the current state of processing for the obligation (of
/// type `O`) associated with this node.
@@ -262,7 +283,7 @@
nodes: vec![],
done_cache: Default::default(),
waiting_cache: Default::default(),
- scratch: Some(vec![]),
+ scratch: RefCell::new(vec![]),
obligation_tree_id_generator: (0..).map(ObligationTreeId),
error_cache: Default::default(),
}
@@ -275,14 +296,12 @@
}
/// Registers an obligation.
- ///
- /// This CAN be done in a snapshot
pub fn register_obligation(&mut self, obligation: O) {
// Ignore errors here - there is no guarantee of success.
let _ = self.register_obligation_at(obligation, None);
}
- // returns Err(()) if we already know this obligation failed.
+ // Returns Err(()) if we already know this obligation failed.
fn register_obligation_at(&mut self, obligation: O, parent: Option<NodeIndex>)
-> Result<(), ()>
{
@@ -294,15 +313,16 @@
Entry::Occupied(o) => {
debug!("register_obligation_at({:?}, {:?}) - duplicate of {:?}!",
obligation, parent, o.get());
- let node = &mut self.nodes[o.get().get()];
- if let Some(parent) = parent {
+ let node = &mut self.nodes[o.get().index()];
+ if let Some(parent_index) = parent {
// If the node is already in `waiting_cache`, it's already
// been marked with a parent. (It's possible that parent
// has been cleared by `apply_rewrites`, though.) So just
// dump `parent` into `node.dependents`... unless it's
// already in `node.dependents` or `node.parent`.
- if !node.dependents.contains(&parent) && Some(parent) != node.parent {
- node.dependents.push(parent);
+ if !node.dependents.contains(&parent_index) &&
+ Some(parent_index) != node.parent {
+ node.dependents.push(parent_index);
}
}
if let NodeState::Error = node.state.get() {
@@ -316,9 +336,8 @@
obligation, parent, self.nodes.len());
let obligation_tree_id = match parent {
- Some(p) => {
- let parent_node = &self.nodes[p.get()];
- parent_node.obligation_tree_id
+ Some(parent_index) => {
+ self.nodes[parent_index.index()].obligation_tree_id
}
None => self.obligation_tree_id_generator.next().unwrap()
};
@@ -342,13 +361,11 @@
}
/// Converts all remaining obligations to the given error.
- ///
- /// This cannot be done during a snapshot.
pub fn to_errors<E: Clone>(&mut self, error: E) -> Vec<Error<O, E>> {
let mut errors = vec![];
- for index in 0..self.nodes.len() {
- if let NodeState::Pending = self.nodes[index].state.get() {
- let backtrace = self.error_at(index);
+ for (i, node) in self.nodes.iter().enumerate() {
+ if let NodeState::Pending = node.state.get() {
+ let backtrace = self.error_at(i);
errors.push(Error {
error: error.clone(),
backtrace,
@@ -373,7 +390,6 @@
fn insert_into_error_cache(&mut self, node_index: usize) {
let node = &self.nodes[node_index];
-
self.error_cache
.entry(node.obligation_tree_id)
.or_default()
@@ -393,16 +409,22 @@
let mut errors = vec![];
let mut stalled = true;
- for index in 0..self.nodes.len() {
- debug!("process_obligations: node {} == {:?}", index, self.nodes[index]);
+ for i in 0..self.nodes.len() {
+ let node = &mut self.nodes[i];
- let result = match self.nodes[index] {
- Node { ref state, ref mut obligation, .. } if state.get() == NodeState::Pending =>
- processor.process_obligation(obligation),
+ debug!("process_obligations: node {} == {:?}", i, node);
+
+ // `processor.process_obligation` can modify the predicate within
+ // `node.obligation`, and that predicate is the key used for
+ // `self.waiting_cache`. This means that `self.waiting_cache` can
+ // get out of sync with `nodes`. It's not very common, but it does
+ // happen, and code in `compress` has to allow for it.
+ let result = match node.state.get() {
+ NodeState::Pending => processor.process_obligation(&mut node.obligation),
_ => continue
};
- debug!("process_obligations: node {} got result {:?}", index, result);
+ debug!("process_obligations: node {} got result {:?}", i, result);
match result {
ProcessResult::Unchanged => {
@@ -411,23 +433,23 @@
ProcessResult::Changed(children) => {
// We are not (yet) stalled.
stalled = false;
- self.nodes[index].state.set(NodeState::Success);
+ node.state.set(NodeState::Success);
for child in children {
let st = self.register_obligation_at(
child,
- Some(NodeIndex::new(index))
+ Some(NodeIndex::new(i))
);
if let Err(()) = st {
- // error already reported - propagate it
+ // Error already reported - propagate it
// to our node.
- self.error_at(index);
+ self.error_at(i);
}
}
}
ProcessResult::Error(err) => {
stalled = false;
- let backtrace = self.error_at(index);
+ let backtrace = self.error_at(i);
errors.push(Error {
error: err,
backtrace,
@@ -448,8 +470,6 @@
self.mark_as_waiting();
self.process_cycles(processor);
-
- // Now we have to compress the result
let completed = self.compress(do_completed);
debug!("process_obligations: complete");
@@ -465,97 +485,92 @@
/// report all cycles between them. This should be called
/// after `mark_as_waiting` marks all nodes with pending
/// subobligations as NodeState::Waiting.
- fn process_cycles<P>(&mut self, processor: &mut P)
+ fn process_cycles<P>(&self, processor: &mut P)
where P: ObligationProcessor<Obligation=O>
{
- let mut stack = self.scratch.take().unwrap();
+ let mut stack = self.scratch.replace(vec![]);
debug_assert!(stack.is_empty());
debug!("process_cycles()");
- for index in 0..self.nodes.len() {
+ for (i, node) in self.nodes.iter().enumerate() {
// For rustc-benchmarks/inflate-0.1.0 this state test is extremely
// hot and the state is almost always `Pending` or `Waiting`. It's
// a win to handle the no-op cases immediately to avoid the cost of
// the function call.
- let state = self.nodes[index].state.get();
- match state {
+ match node.state.get() {
NodeState::Waiting | NodeState::Pending | NodeState::Done | NodeState::Error => {},
- _ => self.find_cycles_from_node(&mut stack, processor, index),
+ _ => self.find_cycles_from_node(&mut stack, processor, i),
}
}
debug!("process_cycles: complete");
debug_assert!(stack.is_empty());
- self.scratch = Some(stack);
+ self.scratch.replace(stack);
}
- fn find_cycles_from_node<P>(&self, stack: &mut Vec<usize>,
- processor: &mut P, index: usize)
+ fn find_cycles_from_node<P>(&self, stack: &mut Vec<usize>, processor: &mut P, i: usize)
where P: ObligationProcessor<Obligation=O>
{
- let node = &self.nodes[index];
- let state = node.state.get();
- match state {
+ let node = &self.nodes[i];
+ match node.state.get() {
NodeState::OnDfsStack => {
- let index =
- stack.iter().rposition(|n| *n == index).unwrap();
- processor.process_backedge(stack[index..].iter().map(GetObligation(&self.nodes)),
+ let i = stack.iter().rposition(|n| *n == i).unwrap();
+ processor.process_backedge(stack[i..].iter().map(GetObligation(&self.nodes)),
PhantomData);
}
NodeState::Success => {
node.state.set(NodeState::OnDfsStack);
- stack.push(index);
- for dependent in node.parent.iter().chain(node.dependents.iter()) {
- self.find_cycles_from_node(stack, processor, dependent.get());
+ stack.push(i);
+ for index in node.parent.iter().chain(node.dependents.iter()) {
+ self.find_cycles_from_node(stack, processor, index.index());
}
stack.pop();
node.state.set(NodeState::Done);
},
NodeState::Waiting | NodeState::Pending => {
- // this node is still reachable from some pending node. We
+ // This node is still reachable from some pending node. We
// will get to it when they are all processed.
}
NodeState::Done | NodeState::Error => {
- // already processed that node
+ // Already processed that node.
}
};
}
/// Returns a vector of obligations for `p` and all of its
/// ancestors, putting them into the error state in the process.
- fn error_at(&mut self, p: usize) -> Vec<O> {
- let mut error_stack = self.scratch.take().unwrap();
+ fn error_at(&self, mut i: usize) -> Vec<O> {
+ let mut error_stack = self.scratch.replace(vec![]);
let mut trace = vec![];
- let mut n = p;
loop {
- self.nodes[n].state.set(NodeState::Error);
- trace.push(self.nodes[n].obligation.clone());
- error_stack.extend(self.nodes[n].dependents.iter().map(|x| x.get()));
+ let node = &self.nodes[i];
+ node.state.set(NodeState::Error);
+ trace.push(node.obligation.clone());
+ error_stack.extend(node.dependents.iter().map(|index| index.index()));
- // loop to the parent
- match self.nodes[n].parent {
- Some(q) => n = q.get(),
+ // Loop to the parent.
+ match node.parent {
+ Some(parent_index) => i = parent_index.index(),
None => break
}
}
while let Some(i) = error_stack.pop() {
- match self.nodes[i].state.get() {
+ let node = &self.nodes[i];
+ match node.state.get() {
NodeState::Error => continue,
- _ => self.nodes[i].state.set(NodeState::Error),
+ _ => node.state.set(NodeState::Error),
}
- let node = &self.nodes[i];
-
error_stack.extend(
- node.parent.iter().chain(node.dependents.iter()).map(|x| x.get())
+ node.parent.iter().chain(node.dependents.iter()).map(|index| index.index())
);
}
- self.scratch = Some(error_stack);
+ self.scratch.replace(error_stack);
trace
}
@@ -563,7 +578,7 @@
#[inline(always)]
fn inlined_mark_neighbors_as_waiting_from(&self, node: &Node<O>) {
for dependent in node.parent.iter().chain(node.dependents.iter()) {
- self.mark_as_waiting_from(&self.nodes[dependent.get()]);
+ self.mark_as_waiting_from(&self.nodes[dependent.index()]);
}
}
@@ -609,7 +624,7 @@
#[inline(never)]
fn compress(&mut self, do_completed: DoCompleted) -> Option<Vec<O>> {
let nodes_len = self.nodes.len();
- let mut node_rewrites: Vec<_> = self.scratch.take().unwrap();
+ let mut node_rewrites: Vec<_> = self.scratch.replace(vec![]);
node_rewrites.extend(0..nodes_len);
let mut dead_nodes = 0;
@@ -620,7 +635,8 @@
// self.nodes[i - dead_nodes..i] are all dead
// self.nodes[i..] are unchanged
for i in 0..self.nodes.len() {
- match self.nodes[i].state.get() {
+ let node = &self.nodes[i];
+ match node.state.get() {
NodeState::Pending | NodeState::Waiting => {
if dead_nodes > 0 {
self.nodes.swap(i, i - dead_nodes);
@@ -628,13 +644,16 @@
}
}
NodeState::Done => {
- // Avoid cloning the key (predicate) in case it exists in the waiting cache
+ // This lookup can fail because the contents of
+ // `self.waiting_cache` are not guaranteed to match those of
+ // `self.nodes`. See the comment in `process_obligations`
+ // for more details.
if let Some((predicate, _)) = self.waiting_cache
- .remove_entry(self.nodes[i].obligation.as_predicate())
+ .remove_entry(node.obligation.as_predicate())
{
self.done_cache.insert(predicate);
} else {
- self.done_cache.insert(self.nodes[i].obligation.as_predicate().clone());
+ self.done_cache.insert(node.obligation.as_predicate().clone());
}
node_rewrites[i] = nodes_len;
dead_nodes += 1;
@@ -643,7 +662,7 @@
// We *intentionally* remove the node from the cache at this point. Otherwise
// tests must come up with a different type on every type error they
// check against.
- self.waiting_cache.remove(self.nodes[i].obligation.as_predicate());
+ self.waiting_cache.remove(node.obligation.as_predicate());
node_rewrites[i] = nodes_len;
dead_nodes += 1;
self.insert_into_error_cache(i);
@@ -655,12 +674,11 @@
// No compression needed.
if dead_nodes == 0 {
node_rewrites.truncate(0);
- self.scratch = Some(node_rewrites);
+ self.scratch.replace(node_rewrites);
return if do_completed == DoCompleted::Yes { Some(vec![]) } else { None };
}
- // Pop off all the nodes we killed and extract the success
- // stories.
+ // Pop off all the nodes we killed and extract the success stories.
let successful = if do_completed == DoCompleted::Yes {
Some((0..dead_nodes)
.map(|_| self.nodes.pop().unwrap())
@@ -679,7 +697,7 @@
self.apply_rewrites(&node_rewrites);
node_rewrites.truncate(0);
- self.scratch = Some(node_rewrites);
+ self.scratch.replace(node_rewrites);
successful
}
@@ -689,58 +707,41 @@
for node in &mut self.nodes {
if let Some(index) = node.parent {
- let new_index = node_rewrites[index.get()];
- if new_index >= nodes_len {
- // parent dead due to error
+ let new_i = node_rewrites[index.index()];
+ if new_i >= nodes_len {
node.parent = None;
} else {
- node.parent = Some(NodeIndex::new(new_index));
+ node.parent = Some(NodeIndex::new(new_i));
}
}
let mut i = 0;
while i < node.dependents.len() {
- let new_index = node_rewrites[node.dependents[i].get()];
- if new_index >= nodes_len {
+ let new_i = node_rewrites[node.dependents[i].index()];
+ if new_i >= nodes_len {
node.dependents.swap_remove(i);
} else {
- node.dependents[i] = NodeIndex::new(new_index);
+ node.dependents[i] = NodeIndex::new(new_i);
i += 1;
}
}
}
- let mut kill_list = vec![];
- for (predicate, index) in &mut self.waiting_cache {
- let new_index = node_rewrites[index.get()];
- if new_index >= nodes_len {
- kill_list.push(predicate.clone());
+ // This updating of `self.waiting_cache` is necessary because the
+ // removal of nodes within `compress` can fail. See above.
+ self.waiting_cache.retain(|_predicate, index| {
+ let new_i = node_rewrites[index.index()];
+ if new_i >= nodes_len {
+ false
} else {
- *index = NodeIndex::new(new_index);
+ *index = NodeIndex::new(new_i);
+ true
}
- }
-
- for predicate in kill_list { self.waiting_cache.remove(&predicate); }
+ });
}
}
-impl<O> Node<O> {
- fn new(
- parent: Option<NodeIndex>,
- obligation: O,
- obligation_tree_id: ObligationTreeId
- ) -> Node<O> {
- Node {
- obligation,
- state: Cell::new(NodeState::Pending),
- parent,
- dependents: vec![],
- obligation_tree_id,
- }
- }
-}
-
-// I need a Clone closure
+// I need a Clone closure.
#[derive(Clone)]
struct GetObligation<'a, O>(&'a [Node<O>]);
diff --git a/src/librustc_data_structures/obligation_forest/node_index.rs b/src/librustc_data_structures/obligation_forest/node_index.rs
deleted file mode 100644
index 69ea473..0000000
--- a/src/librustc_data_structures/obligation_forest/node_index.rs
+++ /dev/null
@@ -1,20 +0,0 @@
-use std::num::NonZeroU32;
-use std::u32;
-
-#[derive(Copy, Clone, Debug, PartialEq, Eq)]
-pub struct NodeIndex {
- index: NonZeroU32,
-}
-
-impl NodeIndex {
- #[inline]
- pub fn new(value: usize) -> NodeIndex {
- assert!(value < (u32::MAX as usize));
- NodeIndex { index: NonZeroU32::new((value as u32) + 1).unwrap() }
- }
-
- #[inline]
- pub fn get(self) -> usize {
- (self.index.get() - 1) as usize
- }
-}