| use crate::signal::Signal; |
| use salsa::Database; |
| use salsa::ParallelDatabase; |
| use salsa::Snapshot; |
| use std::sync::Arc; |
| use std::{ |
| cell::Cell, |
| panic::{catch_unwind, resume_unwind, AssertUnwindSafe}, |
| }; |
| |
| #[salsa::query_group(Par)] |
| pub(crate) trait ParDatabase: Knobs { |
| #[salsa::input] |
| fn input(&self, key: char) -> usize; |
| |
| fn sum(&self, key: &'static str) -> usize; |
| |
| /// Invokes `sum` |
| fn sum2(&self, key: &'static str) -> usize; |
| |
| /// Invokes `sum` but doesn't really care about the result. |
| fn sum2_drop_sum(&self, key: &'static str) -> usize; |
| |
| /// Invokes `sum2` |
| fn sum3(&self, key: &'static str) -> usize; |
| |
| /// Invokes `sum2_drop_sum` |
| fn sum3_drop_sum(&self, key: &'static str) -> usize; |
| } |
| |
| /// Various "knobs" and utilities used by tests to force |
| /// a certain behavior. |
| pub(crate) trait Knobs { |
| fn knobs(&self) -> &KnobsStruct; |
| |
| fn signal(&self, stage: usize); |
| |
| fn wait_for(&self, stage: usize); |
| } |
| |
| pub(crate) trait WithValue<T> { |
| fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R; |
| } |
| |
| impl<T> WithValue<T> for Cell<T> { |
| fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R { |
| let old_value = self.replace(value); |
| |
| let result = catch_unwind(AssertUnwindSafe(closure)); |
| |
| self.set(old_value); |
| |
| match result { |
| Ok(r) => r, |
| Err(payload) => resume_unwind(payload), |
| } |
| } |
| } |
| |
| #[derive(Default, Clone, Copy, PartialEq, Eq)] |
| pub(crate) enum CancellationFlag { |
| #[default] |
| Down, |
| Panic, |
| } |
| |
| /// Various "knobs" that can be used to customize how the queries |
| /// behave on one specific thread. Note that this state is |
| /// intentionally thread-local (apart from `signal`). |
| #[derive(Clone, Default)] |
| pub(crate) struct KnobsStruct { |
| /// A kind of flexible barrier used to coordinate execution across |
| /// threads to ensure we reach various weird states. |
| pub(crate) signal: Arc<Signal>, |
| |
| /// When this database is about to block, send a signal. |
| pub(crate) signal_on_will_block: Cell<usize>, |
| |
| /// Invocations of `sum` will signal this stage on entry. |
| pub(crate) sum_signal_on_entry: Cell<usize>, |
| |
| /// Invocations of `sum` will wait for this stage on entry. |
| pub(crate) sum_wait_for_on_entry: Cell<usize>, |
| |
| /// If true, invocations of `sum` will panic before they exit. |
| pub(crate) sum_should_panic: Cell<bool>, |
| |
| /// If true, invocations of `sum` will wait for cancellation before |
| /// they exit. |
| pub(crate) sum_wait_for_cancellation: Cell<CancellationFlag>, |
| |
| /// Invocations of `sum` will wait for this stage prior to exiting. |
| pub(crate) sum_wait_for_on_exit: Cell<usize>, |
| |
| /// Invocations of `sum` will signal this stage prior to exiting. |
| pub(crate) sum_signal_on_exit: Cell<usize>, |
| |
| /// Invocations of `sum3_drop_sum` will panic unconditionally |
| pub(crate) sum3_drop_sum_should_panic: Cell<bool>, |
| } |
| |
| fn sum(db: &dyn ParDatabase, key: &'static str) -> usize { |
| let mut sum = 0; |
| |
| db.signal(db.knobs().sum_signal_on_entry.get()); |
| |
| db.wait_for(db.knobs().sum_wait_for_on_entry.get()); |
| |
| if db.knobs().sum_should_panic.get() { |
| panic!("query set to panic before exit") |
| } |
| |
| for ch in key.chars() { |
| sum += db.input(ch); |
| } |
| |
| match db.knobs().sum_wait_for_cancellation.get() { |
| CancellationFlag::Down => (), |
| CancellationFlag::Panic => { |
| tracing::debug!("waiting for cancellation"); |
| loop { |
| db.unwind_if_cancelled(); |
| std::thread::yield_now(); |
| } |
| } |
| } |
| |
| db.wait_for(db.knobs().sum_wait_for_on_exit.get()); |
| |
| db.signal(db.knobs().sum_signal_on_exit.get()); |
| |
| sum |
| } |
| |
| fn sum2(db: &dyn ParDatabase, key: &'static str) -> usize { |
| db.sum(key) |
| } |
| |
| fn sum2_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize { |
| let _ = db.sum(key); |
| 22 |
| } |
| |
| fn sum3(db: &dyn ParDatabase, key: &'static str) -> usize { |
| db.sum2(key) |
| } |
| |
| fn sum3_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize { |
| if db.knobs().sum3_drop_sum_should_panic.get() { |
| panic!("sum3_drop_sum executed") |
| } |
| db.sum2_drop_sum(key) |
| } |
| |
| #[salsa::database( |
| Par, |
| crate::parallel_cycle_all_recover::ParallelCycleAllRecover, |
| crate::parallel_cycle_none_recover::ParallelCycleNoneRecover, |
| crate::parallel_cycle_mid_recover::ParallelCycleMidRecovers, |
| crate::parallel_cycle_one_recovers::ParallelCycleOneRecovers |
| )] |
| #[derive(Default)] |
| pub(crate) struct ParDatabaseImpl { |
| storage: salsa::Storage<Self>, |
| knobs: KnobsStruct, |
| } |
| |
| impl Database for ParDatabaseImpl { |
| fn salsa_event(&self, event: salsa::Event) { |
| if let salsa::EventKind::WillBlockOn { .. } = event.kind { |
| self.signal(self.knobs().signal_on_will_block.get()); |
| } |
| } |
| } |
| |
| impl ParallelDatabase for ParDatabaseImpl { |
| fn snapshot(&self) -> Snapshot<Self> { |
| Snapshot::new(ParDatabaseImpl { |
| storage: self.storage.snapshot(), |
| knobs: self.knobs.clone(), |
| }) |
| } |
| } |
| |
| impl Knobs for ParDatabaseImpl { |
| fn knobs(&self) -> &KnobsStruct { |
| &self.knobs |
| } |
| |
| fn signal(&self, stage: usize) { |
| self.knobs.signal.signal(stage); |
| } |
| |
| fn wait_for(&self, stage: usize) { |
| self.knobs.signal.wait_for(stage); |
| } |
| } |