// Copyright 2023 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use crate::task::CurrentTask;
use starnix_uapi::errors::Errno;

use core::marker::PhantomData;

use starnix_sync::{InterruptibleEvent, LockBefore, Locked, Mutex};
use std::collections::VecDeque;
use std::sync::Arc;

// Keep the `lock_api` dependency referenced in production builds, where the
// tracer below (and its `RawRwLock` import) is compiled out.
use lock_api as _;

#[cfg(any(test, debug_assertions))]
use lock_api::RawRwLock;

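/// A fair, interruptible reader-writer lock.
///
/// Blocked tasks wait on an `InterruptibleEvent`, so `CurrentTask::block_until`
/// can abandon the wait (for example, when the task is interrupted), in which
/// case the acquisition fails with an `Errno`. Waiters are admitted in FIFO
/// order, and new readers do not jump ahead of queued waiters, so a pending
/// writer cannot be starved by a stream of later readers. The `L` parameter
/// ties the queue into the `starnix_sync` lock-ordering machinery.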
#[derive(Debug)]
pub struct RwQueue<L> {
    inner: Mutex<RwQueueInner>,
    _phantom: PhantomData<L>,

    // Used to inform our deadlock detector about the waiters in the queue.
    #[cfg(any(test, debug_assertions))]
    tracer: tracer::MutexTracer,
}

impl<L> RwQueue<L> {
    // Acquires a read lock without checking lock ordering.
    // TODO(https://fxbug.dev/333540469): This should be a part of the implementation
    // of an OrderedRwLock. However, this requires that OrderedRwLock accepts the
    // `read()` method that uses a context (in this case, `CurrentTask`).
    fn read_internal(&self, current_task: &CurrentTask) -> Result<(), Errno> {
        #[cfg(any(test, debug_assertions))]
        self.tracer.lock_shared();

        let mut inner = self.inner.lock();

        if !inner.try_read() {
            let event = InterruptibleEvent::new();
            let guard = event.begin_wait();

            inner.waiters.push_back(Waiter::Reader(event.clone()));

            std::mem::drop(inner);

            current_task.block_until(guard, zx::MonotonicInstant::INFINITE).map_err(|e| {
                self.inner.lock().remove_waiter(&event);
                e
            })?;
        }
        Ok(())
    }

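    /// Acquires a read lock, blocking the current task until no writer holds
    /// the queue, and returns the guard together with a `Locked<L>` token for
    /// acquiring locks ordered after `L`. Fails if the wait is interrupted.
    ///
    /// A minimal sketch, where `next_lock` is some hypothetical lock ordered
    /// after `L`:
    ///
    /// ```ignore
    /// let (guard, locked) = queue.read_and(locked, current_task)?;
    /// let data = next_lock.lock(locked);
    /// ```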
    pub fn read_and<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
        current_task: &CurrentTask,
    ) -> Result<(RwQueueReadGuard<'a, L>, &'a mut Locked<L>), Errno>
    where
        P: LockBefore<L>,
    {
        self.read_internal(current_task)?;

        let new_locked = locked.cast_locked::<L>();

        Ok((RwQueueReadGuard { queue: self }, new_locked))
    }

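    /// Acquires the write lock, blocking the current task until the queue is
    /// idle, and returns the guard together with a `Locked<L>` token. Fails if
    /// the wait is interrupted.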
    pub fn write_and<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
        current_task: &CurrentTask,
    ) -> Result<(RwQueueWriteGuard<'a, L>, &'a mut Locked<L>), Errno>
    where
        P: LockBefore<L>,
    {
        #[cfg(any(test, debug_assertions))]
        self.tracer.lock_exclusive();

        let mut inner = self.inner.lock();

        if !inner.try_write() {
            let event = InterruptibleEvent::new();
            let guard = event.begin_wait();

            inner.waiters.push_back(Waiter::Writer(event.clone()));

            std::mem::drop(inner);

            current_task.block_until(guard, zx::MonotonicInstant::INFINITE).map_err(|e| {
                self.inner.lock().remove_waiter(&event);
                e
            })?;
        }

        let new_locked = locked.cast_locked::<L>();
        Ok((RwQueueWriteGuard { queue: self }, new_locked))
    }

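    /// Like `read_and`, but discards the `Locked<L>` token.
    ///
    /// A minimal usage sketch, assuming a `locked: &mut Locked<Unlocked>` and
    /// a `current_task: &CurrentTask` such as the tests below obtain:
    ///
    /// ```ignore
    /// let queue = RwQueue::<TestLevel>::default();
    /// let guard = queue.read(locked, current_task)?;
    /// // ... read the protected state ...
    /// drop(guard); // Wakes the next queued waiter, if any.
    /// ```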
    pub fn read<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
        current_task: &CurrentTask,
    ) -> Result<RwQueueReadGuard<'a, L>, Errno>
    where
        P: LockBefore<L>,
    {
        self.read_and(locked, current_task).map(|(g, _)| g)
    }

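    /// Like `write_and`, but discards the `Locked<L>` token. See `read` above
    /// for a usage sketch.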
    pub fn write<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
        current_task: &CurrentTask,
    ) -> Result<RwQueueWriteGuard<'a, L>, Errno>
    where
        P: LockBefore<L>,
    {
        self.write_and(locked, current_task).map(|(g, _)| g)
    }

    /// Acquires a read lock without blocking, used only to establish lock
    /// ordering with the deadlock detector in tests and debug builds. Panics
    /// if the lock is contended.
    #[cfg(any(test, debug_assertions))]
    pub fn read_for_lock_ordering<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
    ) -> (RwQueueReadGuard<'a, L>, &'a mut Locked<L>)
    where
        P: LockBefore<L>,
    {
        self.tracer.lock_shared();

        assert!(self.inner.lock().try_read(), "Cannot fail to acquire a read for lock ordering.");
        let new_locked = locked.cast_locked::<L>();

        (RwQueueReadGuard { queue: self }, new_locked)
    }

    fn unlock_read(&self) {
        self.inner.lock().unlock_read();

        #[cfg(any(test, debug_assertions))]
        // SAFETY: `lock_shared` was called on the tracer when this read lock
        // was acquired, so this unlock balances it.
        unsafe {
            self.tracer.unlock_shared();
        }
    }

    fn unlock_write(&self) {
        self.inner.lock().unlock_write();

        #[cfg(any(test, debug_assertions))]
        // SAFETY: `lock_exclusive` was called on the tracer when this write
        // lock was acquired, so this unlock balances it.
        unsafe {
            self.tracer.unlock_exclusive();
        }
    }
}

impl<L> Default for RwQueue<L> {
    fn default() -> Self {
        Self {
            inner: Default::default(),
            #[cfg(any(test, debug_assertions))]
            tracer: Default::default(),
            _phantom: Default::default(),
        }
    }
}

/// The queue is ready for any operation.
const READY: usize = 0;

/// Set while exactly one writer holds the lock.
const WRITER: usize = 0b01;

/// Each reader currently holding the lock increments the state by this amount.
const READER: usize = 0b10;

// A consistent state is therefore either `READY`, exactly `WRITER`, or
// `n * READER` for `n` concurrent readers. For example, a state of `0b110`
// (`3 * READER`) encodes three readers and no writer.

/// Whether a writer currently holds the lock.
fn has_writer(state: usize) -> bool {
    state & WRITER != 0
}

/// Whether at least one reader currently holds the lock. Any state of at
/// least `READER` qualifies, because a writer's state is exactly `WRITER`.
fn has_reader(state: usize) -> bool {
    state >= READER
}

fn debug_assert_consistent(state: usize) {
    debug_assert!(!has_writer(state) || !has_reader(state));
}

#[derive(Debug, Clone)]
enum Waiter {
    Reader(Arc<InterruptibleEvent>),
    Writer(Arc<InterruptibleEvent>),
}

#[derive(Debug, Default)]
struct RwQueueInner {
    /// What operations are currently ongoing.
    ///
    /// See READY, READER, WRITER above for what these bits mean.
    state: usize,

    /// The operations that are waiting for the ongoing operations to complete.
    waiters: VecDeque<Waiter>,
}

impl RwQueueInner {
    fn has_waiters(&self) -> bool {
        !self.waiters.is_empty()
    }

    fn try_read(&mut self) -> bool {
        debug_assert_consistent(self.state);
        if !has_writer(self.state) && !self.has_waiters() {
            if let Some(new_state) = self.state.checked_add(READER) {
                self.state = new_state;
                return true;
            }
        }
        false
    }

    fn try_write(&mut self) -> bool {
        debug_assert_consistent(self.state);
        if self.state == READY && !self.has_waiters() {
            self.state += WRITER;
            true
        } else {
            false
        }
    }

    fn unlock_read(&mut self) {
        debug_assert!(has_reader(self.state) && !has_writer(self.state));
        self.state -= READER;

        if !has_reader(self.state) && self.has_waiters() {
            self.notify_next();
        }
    }

    fn unlock_write(&mut self) {
        debug_assert!(has_writer(self.state) && !has_reader(self.state));
        self.state -= WRITER;

        if self.has_waiters() {
            self.notify_next();
        }
    }

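    /// Admits waiters from the front of the queue in FIFO order: a run of
    /// consecutive readers is admitted together, while a writer is admitted
    /// only when nothing else holds the lock. Stops at the first waiter that
    /// cannot run concurrently with those already admitted.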
    fn notify_next(&mut self) {
        while let Some(waiter) = self.waiters.front() {
            match waiter {
                Waiter::Reader(reader) => {
                    if has_writer(self.state) {
                        return;
                    }
                    // We need to use `checked_add` to ensure we do not
                    // overflow the number of readers. If that happens, we just
                    // need to wait for the enormous number of readers to finish.
                    let Some(new_state) = self.state.checked_add(READER) else {
                        return;
                    };
                    self.state = new_state;
                    reader.notify();
                }
                Waiter::Writer(writer) => {
                    if has_reader(self.state) || has_writer(self.state) {
                        return;
                    }
                    // We can never overflow writers because we only let one
                    // through at a time.
                    self.state += WRITER;
                    writer.notify();
                }
            }
            self.waiters.pop_front();
        }
        debug_assert_consistent(self.state);
    }

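    /// Removes `event` from the wait queue. Called when a blocked task is
    /// interrupted before it was admitted.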
    fn remove_waiter(&mut self, event: &Arc<InterruptibleEvent>) {
        self.waiters.retain(|waiter| {
            let (Waiter::Reader(other) | Waiter::Writer(other)) = waiter;
            !Arc::ptr_eq(event, other)
        });
    }
}

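/// RAII guard that releases the read lock when dropped.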
pub struct RwQueueReadGuard<'a, L> {
    queue: &'a RwQueue<L>,
}

impl<'a, L> Drop for RwQueueReadGuard<'a, L> {
    fn drop(&mut self) {
        self.queue.unlock_read();
    }
}

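/// RAII guard that releases the write lock when dropped.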
pub struct RwQueueWriteGuard<'a, L> {
    queue: &'a RwQueue<L>,
}

impl<'a, L> Drop for RwQueueWriteGuard<'a, L> {
    fn drop(&mut self) {
        self.queue.unlock_write();
    }
}

#[cfg(any(test, debug_assertions))]
mod tracer {

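    /// A no-op lock. `TracingWrapper` only needs an inner `RawRwLock` to wrap;
    /// the dependency tracking that detects lock cycles lives in the wrapper
    /// itself.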
    #[derive(Debug, Default)]
    pub struct FakeRwLock {}

    // SAFETY: Every method is a no-op; there is no lock state for the
    // `RawRwLock` contract to observe.
    unsafe impl lock_api::RawRwLock for FakeRwLock {
        const INIT: Self = Self {};

        type GuardMarker = lock_api::GuardNoSend;

        fn lock_shared(&self) {}
        fn try_lock_shared(&self) -> bool {
            false
        }
        unsafe fn unlock_shared(&self) {}

        fn lock_exclusive(&self) {}
        fn try_lock_exclusive(&self) -> bool {
            false
        }
        unsafe fn unlock_exclusive(&self) {}

        fn is_locked(&self) -> bool {
            false
        }
    }

    // We should replace this type with tracing_mutex::MutexId once that type is public.
    pub type MutexTracer = tracing_mutex::lockapi::TracingWrapper<FakeRwLock>;
}

// We use tracing_mutex in tests and debug assertions, but we don't want to pull it in for
// production.
#[cfg(not(any(test, debug_assertions)))]
use tracing_mutex as _;

#[cfg(test)]
mod test {
    use super::*;
    use crate::task::Kernel;
    use crate::task::dynamic_thread_spawner::SpawnRequestBuilder;
    use crate::testing::*;
    use futures::executor::block_on;
    use futures::future::join_all;
    use starnix_sync::{Unlocked, lock_ordering};
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Barrier;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[::fuchsia::test]
    fn test_remove_from_queue() {
        let mut inner = RwQueueInner::default();
        let event1 = InterruptibleEvent::new();
        let event2 = InterruptibleEvent::new();
        let event3 = InterruptibleEvent::new();
        inner.waiters.push_back(Waiter::Writer(event1.clone()));
        inner.waiters.push_back(Waiter::Writer(event2.clone()));
        inner.waiters.push_back(Waiter::Writer(event3.clone()));

        inner.remove_waiter(&event2);

        let waiter = inner.waiters.pop_front().expect("should have a waiter");
        let Waiter::Writer(event) = waiter else {
            unreachable!();
        };
        assert!(Arc::ptr_eq(&event1, &event));

        let waiter = inner.waiters.pop_front().expect("should have a waiter");
        let Waiter::Writer(event) = waiter else {
            unreachable!();
        };
        assert!(Arc::ptr_eq(&event3, &event));

        assert!(inner.waiters.is_empty());
    }

    #[::fuchsia::test]
    async fn test_write_and_read() {
        lock_ordering! {
            Unlocked => TestLevel
        }

        spawn_kernel_and_run(async |locked, current_task| {
            let queue = RwQueue::<TestLevel>::default();
            let read_guard1 = queue.read(locked, current_task).expect("shouldn't be interrupted");
            std::mem::drop(read_guard1);

            let write_guard = queue.write(locked, current_task).expect("shouldn't be interrupted");
            std::mem::drop(write_guard);

            let read_guard2 = queue.read(locked, current_task).expect("shouldn't be interrupted");
            std::mem::drop(read_guard2);
        })
        .await;
    }

    #[::fuchsia::test]
    async fn test_read_in_parallel() {
        spawn_kernel_and_run(async |_, current_task| {
            let kernel = current_task.kernel();
            lock_ordering! {
                Unlocked => TestLevel
            }
            struct Info {
                barrier: Barrier,
                queue: RwQueue<TestLevel>,
            }

            let info =
                Arc::new(Info { barrier: Barrier::new(2), queue: RwQueue::<TestLevel>::default() });

            let info1 = Arc::clone(&info);
            let closure1 = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                let guard =
                    info1.queue.read(locked, current_task).expect("shouldn't be interrupted");
                info1.barrier.wait();
                std::mem::drop(guard);
            };
            let (thread1, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure1).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);

            let info2 = Arc::clone(&info);
            let closure2 = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                let guard =
                    info2.queue.read(locked, current_task).expect("shouldn't be interrupted");
                info2.barrier.wait();
                std::mem::drop(guard);
            };
            let (thread2, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure2).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);

            block_on(async {
                thread1.await.expect("failed to join thread");
                thread2.await.expect("failed to join thread");
            });
        })
        .await;
    }

    lock_ordering! {
        Unlocked => A
    }
    struct State {
        queue: RwQueue<A>,
        gate: Barrier,
        writer_count: AtomicUsize,
        reader_count: AtomicUsize,
    }

    impl State {
        fn new(n: usize) -> State {
            State {
                queue: Default::default(),
                gate: Barrier::new(n),
                writer_count: Default::default(),
                reader_count: Default::default(),
            }
        }

        fn spawn_writer(
            state: Arc<Self>,
            kernel: Arc<Kernel>,
            count: usize,
        ) -> Pin<Box<dyn Future<Output = Result<(), Errno>> + Send>> {
            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                state.gate.wait();
                for _ in 0..count {
                    let guard =
                        state.queue.write(locked, current_task).expect("shouldn't be interrupted");
                    let writer_count = state.writer_count.fetch_add(1, Ordering::Acquire) + 1;
                    let reader_count = state.reader_count.load(Ordering::Acquire);
                    state.writer_count.fetch_sub(1, Ordering::Release);
                    std::mem::drop(guard);
                    assert_eq!(writer_count, 1, "More than one writer held the lock at once.");
                    assert_eq!(
                        reader_count, 0,
                        "A reader and writer held the lock at the same time."
                    );
                }
            };
            let (result, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);
            Box::pin(result)
        }

        fn spawn_reader(
            state: Arc<Self>,
            kernel: Arc<Kernel>,
            count: usize,
        ) -> Pin<Box<dyn Future<Output = Result<(), Errno>> + Send>> {
            let closure = move |locked: &mut Locked<Unlocked>, current_task: &CurrentTask| {
                state.gate.wait();
                for _ in 0..count {
                    let guard =
                        state.queue.read(locked, current_task).expect("shouldn't be interrupted");
                    let reader_count = state.reader_count.fetch_add(1, Ordering::Acquire) + 1;
                    let writer_count = state.writer_count.load(Ordering::Acquire);
                    state.reader_count.fetch_sub(1, Ordering::Release);
                    std::mem::drop(guard);
                    assert_eq!(
                        writer_count, 0,
                        "A reader and writer held the lock at the same time."
                    );
                    assert!(reader_count > 0, "A reader held the lock without being counted.");
                }
            };
            let (result, req) =
                SpawnRequestBuilder::new().with_sync_closure(closure).build_with_async_result();
            kernel.kthreads.spawner().spawn_from_request(req);
            Box::pin(result)
        }
    }

    #[::fuchsia::test]
    async fn test_thundering_reads_and_writes() {
        spawn_kernel_and_run(async |_, current_task| {
            let kernel = current_task.kernel();
            const THREAD_PAIRS: usize = 10;

            let state = Arc::new(State::new(THREAD_PAIRS * 2));
            let mut threads = vec![];
            for _ in 0..THREAD_PAIRS {
                threads.push(State::spawn_writer(Arc::clone(&state), kernel.clone(), 100));
                threads.push(State::spawn_reader(Arc::clone(&state), kernel.clone(), 100));
            }

            block_on(join_all(threads)).into_iter().for_each(|r| r.expect("failed to join thread"));
        })
        .await;
    }
}