// blob: 2b989b55e92268f51e79792ffe61d048db37a033 [file] [log] [blame]
//! A simple mutex implementation.
//!
//! This mutex exposes both blocking and async methods for acquiring a lock.
#![allow(dead_code)]
use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{mpsc, Arc};
use std::thread;
use std::time::{Duration, Instant};
use event_listener::Event;
/// A minimal mutex built from an atomic flag and an [`Event`].
///
/// The flag records whether the lock is currently held, and the event is
/// what blocked lock operations listen on while they wait.
struct Mutex<T> {
    /// `true` while the mutex is locked, `false` otherwise.
    locked: AtomicBool,

    /// Event that blocked lock operations register listeners on.
    lock_ops: Event,

    /// The value guarded by this mutex.
    data: UnsafeCell<T>,
}

// SAFETY: moving the mutex to another thread moves the `T` it owns along
// with it, so `T: Send` is required.
unsafe impl<T> Send for Mutex<T> where T: Send {}

// SAFETY: sharing the mutex lets whichever thread wins the `locked` flag
// reach the inner `T`, which amounts to sending access to `T` between
// threads; hence the `T: Send` bound.
unsafe impl<T> Sync for Mutex<T> where T: Send {}
impl<T> Mutex<T> {
    /// Creates a new, unlocked mutex holding `t`.
    fn new(t: T) -> Mutex<T> {
        Mutex {
            locked: AtomicBool::new(false),
            lock_ops: Event::new(),
            data: UnsafeCell::new(t),
        }
    }

    /// Attempts to acquire the lock without blocking.
    ///
    /// Returns `Some(guard)` if the `locked` flag was clear and has now been
    /// set by this call, or `None` if the mutex was already locked.
    fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
        // `swap` returns the previous value: `false` means we are the ones
        // who just took the lock. Swapping `true` over `true` changes nothing.
        if !self.locked.swap(true, Ordering::Acquire) {
            Some(MutexGuard(self))
        } else {
            None
        }
    }

    /// Blocks the current thread until the lock is acquired.
    fn lock(&self) -> MutexGuard<'_, T> {
        let mut listener = None;

        loop {
            // Attempt grabbing the lock.
            if let Some(guard) = self.try_lock() {
                return guard;
            }

            // Set up an event listener or wait for a notification.
            match listener.take() {
                None => {
                    // Start listening and then try locking again: the retry
                    // after registering covers a notification that is sent
                    // between the failed attempt above and the wait below.
                    listener = Some(self.lock_ops.listen());
                }
                Some(l) => {
                    // Wait until a notification is received, then retry.
                    l.wait();
                }
            }
        }
    }

    /// Blocks until the lock is acquired or `timeout` has elapsed.
    ///
    /// Returns `Some(guard)` on success and `None` if the deadline passed
    /// before the lock could be taken. If `timeout` is so large that the
    /// deadline cannot be represented as an [`Instant`], the call waits
    /// without a deadline instead of panicking.
    fn lock_timeout(&self, timeout: Duration) -> Option<MutexGuard<'_, T>> {
        // `Instant + Duration` panics when the sum is not representable
        // (for example with `Duration::MAX`). Treat that case as "no deadline".
        let deadline = match Instant::now().checked_add(timeout) {
            Some(deadline) => deadline,
            None => return Some(self.lock()),
        };
        let mut listener = None;

        loop {
            // Attempt grabbing the lock.
            if let Some(guard) = self.try_lock() {
                return Some(guard);
            }

            // Set up an event listener or wait for an event.
            match listener.take() {
                None => {
                    // Start listening and then try locking again.
                    listener = Some(self.lock_ops.listen());
                }
                Some(l) => {
                    // Wait until a notification is received; give up once
                    // the deadline is reached without one.
                    if !l.wait_deadline(deadline) {
                        return None;
                    }
                }
            }
        }
    }

    /// Acquires the lock asynchronously.
    ///
    /// Follows the same try / listen / retry loop as [`Mutex::lock`], but
    /// awaits the listener instead of blocking the thread.
    async fn lock_async(&self) -> MutexGuard<'_, T> {
        let mut listener = None;

        loop {
            // Attempt grabbing the lock.
            if let Some(guard) = self.try_lock() {
                return guard;
            }

            // Set up an event listener or wait for an event.
            match listener.take() {
                None => {
                    // Start listening and then try locking again.
                    listener = Some(self.lock_ops.listen());
                }
                Some(l) => {
                    // Wait until a notification is received, then retry.
                    l.await;
                }
            }
        }
    }
}
/// A guard holding a lock.
struct MutexGuard<'a, T>(&'a Mutex<T>);
unsafe impl<T: Send> Send for MutexGuard<'_, T> {}
unsafe impl<T: Sync> Sync for MutexGuard<'_, T> {}
impl<T> Deref for MutexGuard<'_, T> {
    type Target = T;

    /// Returns a shared reference to the value protected by the mutex.
    fn deref(&self) -> &T {
        let inner: *mut T = self.0.data.get();
        // SAFETY: in this file a guard is only constructed by
        // `Mutex::try_lock`, after it flipped `locked` from `false` to
        // `true`; while this guard exists the flag stays set, so no other
        // guard can be handing out references to the same data.
        unsafe { &*inner }
    }
}
impl<T> DerefMut for MutexGuard<'_, T> {
    /// Returns an exclusive reference to the value protected by the mutex.
    fn deref_mut(&mut self) -> &mut T {
        let inner: *mut T = self.0.data.get();
        // SAFETY: in this file a guard is only constructed by
        // `Mutex::try_lock` once it has set the `locked` flag, so this guard
        // is the only one for the mutex, and `&mut self` ensures the guard
        // itself is not aliased while the returned reference lives.
        unsafe { &mut *inner }
    }
}
fn main() {
const N: usize = 10;
// A shared counter.
let counter = Arc::new(Mutex::new(0));
// A channel that signals when all threads are done.
let (tx, rx) = mpsc::channel();
// Spawn a bunch of threads incrementing the counter.
for _ in 0..N {
let counter = counter.clone();
let tx = tx.clone();
thread::spawn(move || {
let mut counter = counter.lock();
*counter += 1;
// If this is the last increment, signal that we're done.
if *counter == N {
tx.send(()).unwrap();
}
});
}
// Wait until the last thread increments the counter.
rx.recv().unwrap();
// The counter must equal the number of threads.
assert_eq!(*counter.lock(), N);
println!("Done!");
}