blob: b82cb59442dae41c0b22bbd323c7a5a0f6b9ce12 [file] [log] [blame]
// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Provides [`Once`], an async-friendly wrapper around `once_cell::sync::OnceCell`.
use async_lock::Mutex;
use once_cell::sync::OnceCell;
use std::future::Future;
/// Wrapper presenting an async interface to a [`OnceCell`].
///
/// The value is written at most once, by a caller-supplied future. An async mutex
/// serializes concurrent initializers so only one of them runs at a time, while callers
/// that find the cell already populated return immediately without touching the mutex.
#[derive(Debug)]
pub struct Once<T> {
    /// Held by the caller that is currently running an initializer future.
    mutex: Mutex<()>,
    /// The stored value; empty until an initializer completes successfully.
    value: OnceCell<T>,
}
impl<T> Default for Once<T> {
fn default() -> Self {
Self { mutex: Mutex::new(()), value: OnceCell::new() }
}
}
impl<T> Once<T> {
    /// Creates an empty `Once`. Its value is supplied later by the first initializer,
    /// passed to [`Once::get_or_init`] or [`Once::get_or_try_init`], that completes.
    pub fn new() -> Self {
        Self { mutex: Mutex::new(()), value: OnceCell::new() }
    }

    /// Async wrapper around [`OnceCell::get_or_init`].
    ///
    /// Returns the stored value if there is one. Otherwise it waits for exclusive access to
    /// initialization. Unless another caller finished initializing in the meantime, it then
    /// awaits `fut` and stores its output. When a value is already present, `fut` is dropped
    /// without being polled.
    ///
    /// If the returned future is dropped, or `fut` panics, before `fut` completes, the cell is
    /// left empty and a later call can initialize it.
    ///
    /// `fut` must not call `get_or_init`/`get_or_try_init` on this same `Once`. Doing so would
    /// wait on the lock its own caller is holding and never complete.
    pub async fn get_or_init<F>(&self, fut: F) -> &T
    where
        F: Future<Output = T>,
    {
        // Reuse the double-checked locking in `get_or_try_init`. The error type is
        // uninhabited, so the `Err` arm can never be reached.
        let infallible = async { Ok::<T, std::convert::Infallible>(fut.await) };
        match self.get_or_try_init(infallible).await {
            Ok(value) => value,
            Err(never) => match never {},
        }
    }

    /// Async wrapper around [`OnceCell::get_or_try_init`].
    ///
    /// Like [`Once::get_or_init`], except that `fut` may fail. On `Err(e)` nothing is stored,
    /// the lock is released, and `Err(e)` is returned, so a later call gets another chance to
    /// initialize the cell.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `fut` when `fut` was awaited and failed.
    pub async fn get_or_try_init<F, E>(&self, fut: F) -> Result<&T, E>
    where
        F: Future<Output = Result<T, E>>,
    {
        // Fast path: once initialized, callers never touch the mutex.
        if let Some(value) = self.value.get() {
            return Ok(value);
        }
        // Only one initializer runs at a time, because the guard is held across `fut.await`.
        let _guard = self.mutex.lock().await;
        // Someone raced us, initialized the cell, and then released the lock.
        if let Some(value) = self.value.get() {
            return Ok(value);
        }
        let value = fut.await?;
        // Every writer of `self.value` holds `self.mutex`, and the cell was still empty after
        // we took the lock. The closure therefore runs and stores `value`; this never returns
        // a racer's value.
        Ok(self.value.get_or_init(|| value))
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use futures_lite::future::{block_on, yield_now, zip};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Calls `get_or_init` twice on an empty `once`. Checks that only the first initializer
    /// runs and that its value is the one retained.
    fn check_get_or_init(once: &Once<bool>) {
        let counter = &AtomicUsize::new(0);
        let val = block_on(once.get_or_init(async move {
            let _: usize = counter.fetch_add(1, Ordering::SeqCst);
            true
        }));
        assert!(*val);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        // The cell is already initialized, so this initializer must be dropped unrun.
        let val = block_on(once.get_or_init(async move {
            let _: usize = counter.fetch_add(1, Ordering::SeqCst);
            false
        }));
        assert!(*val);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_init() {
        check_get_or_init(&Once::new());
    }

    #[test]
    fn test_get_or_init_default_initializer() {
        check_get_or_init(&Once::default());
    }

    #[test]
    fn test_once_is_send_and_sync() {
        // Required to share a `Once` across threads or to store it in a `static`.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Once<bool>>();
    }

    #[test]
    fn test_get_or_try_init() {
        let once = Once::<bool>::new();
        let counter = &AtomicUsize::new(0);
        let initializer = move || async move {
            if counter.fetch_add(1, Ordering::SeqCst) == 0 {
                Err(std::io::Error::new(std::io::ErrorKind::Other, "first attempt fails"))
            } else {
                Ok(true)
            }
        };
        let val = block_on(once.get_or_try_init(initializer()));
        assert!(val.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        // The initializer gets another chance to run because the first attempt failed.
        let val = block_on(once.get_or_try_init(initializer()));
        assert!(*val.unwrap());
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        // The initializer never runs again...
        let val = block_on(once.get_or_try_init(initializer()));
        assert!(*val.unwrap());
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_concurrent_get_or_init_runs_single_initializer() {
        let once = Once::<bool>::new();
        let counter = &AtomicUsize::new(0);
        let initializer = move |result: bool| async move {
            let _: usize = counter.fetch_add(1, Ordering::SeqCst);
            // Suspend mid-initialization so the other caller can observe the empty cell
            // and queue up on the lock.
            yield_now().await;
            result
        };
        let (a, b) = block_on(zip(
            once.get_or_init(initializer(true)),
            once.get_or_init(initializer(false)),
        ));
        // Whichever caller initialized the cell, the other must see that same value rather
        // than running its own initializer.
        assert_eq!(a, b);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}