| // Copyright 2017 Amanieu d'Antras |
| // |
| // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or |
| // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or |
| // http://opensource.org/licenses/MIT>, at your option. This file may not be |
| // copied, modified, or distributed except according to those terms. |
| |
| //! Per-object thread-local storage |
| //! |
| //! This library provides the `ThreadLocal` type which allows a separate copy of |
| //! an object to be used for each thread. This allows for per-object |
| //! thread-local storage, unlike the standard library's `thread_local!` macro |
| //! which only allows static thread-local storage. |
| //! |
| //! Per-thread objects are not destroyed when a thread exits. Instead, objects |
| //! are only destroyed when the `ThreadLocal` containing them is destroyed. |
| //! |
| //! You can also iterate over the thread-local values of all thread in a |
| //! `ThreadLocal` object using the `iter_mut` and `into_iter` methods. This can |
| //! only be done if you have mutable access to the `ThreadLocal` object, which |
| //! guarantees that you are the only thread currently accessing it. |
| //! |
| //! A `CachedThreadLocal` type is also provided which wraps a `ThreadLocal` but |
| //! also uses a special fast path for the first thread that writes into it. The |
| //! fast path has very low overhead (<1ns per access) while keeping the same |
| //! performance as `ThreadLocal` for other threads. |
| //! |
| //! Note that since thread IDs are recycled when a thread exits, it is possible |
| //! for one thread to retrieve the object of another thread. Since this can only |
| //! occur after a thread has exited this does not lead to any race conditions. |
| //! |
| //! # Examples |
| //! |
| //! Basic usage of `ThreadLocal`: |
| //! |
| //! ```rust |
| //! use thread_local::ThreadLocal; |
| //! let tls: ThreadLocal<u32> = ThreadLocal::new(); |
| //! assert_eq!(tls.get(), None); |
| //! assert_eq!(tls.get_or(|| 5), &5); |
| //! assert_eq!(tls.get(), Some(&5)); |
| //! ``` |
| //! |
| //! Combining thread-local values into a single result: |
| //! |
| //! ```rust |
| //! use thread_local::ThreadLocal; |
| //! use std::sync::Arc; |
| //! use std::cell::Cell; |
| //! use std::thread; |
| //! |
| //! let tls = Arc::new(ThreadLocal::new()); |
| //! |
| //! // Create a bunch of threads to do stuff |
| //! for _ in 0..5 { |
| //! let tls2 = tls.clone(); |
| //! thread::spawn(move || { |
| //! // Increment a counter to count some event... |
| //! let cell = tls2.get_or(|| Cell::new(0)); |
| //! cell.set(cell.get() + 1); |
| //! }).join().unwrap(); |
| //! } |
| //! |
| //! // Once all threads are done, collect the counter values and return the |
| //! // sum of all thread-local counter values. |
| //! let tls = Arc::try_unwrap(tls).unwrap(); |
| //! let total = tls.into_iter().fold(0, |x, y| x + y.get()); |
| //! assert_eq!(total, 5); |
| //! ``` |
| |
| #![warn(missing_docs)] |
| |
| #[macro_use] |
| extern crate lazy_static; |
| |
| mod thread_id; |
| mod unreachable; |
| mod cached; |
| |
| pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal}; |
| |
| use std::cell::UnsafeCell; |
| use std::fmt; |
| use std::marker::PhantomData; |
| use std::panic::UnwindSafe; |
| use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering}; |
| use std::sync::Mutex; |
| use unreachable::{UncheckedOptionExt, UncheckedResultExt}; |
| |
| /// Thread-local variable wrapper |
| /// |
| /// See the [module-level documentation](index.html) for more. |
| pub struct ThreadLocal<T: Send> { |
| // Pointer to the current top-level hash table |
| table: AtomicPtr<Table<T>>, |
| |
| // Lock used to guard against concurrent modifications. This is only taken |
| // while writing to the table, not when reading from it. This also guards |
| // the counter for the total number of values in the hash table. |
| lock: Mutex<usize>, |
| } |
| |
| struct Table<T: Send> { |
| // Hash entries for the table |
| entries: Box<[TableEntry<T>]>, |
| |
| // Number of bits used for the hash function |
| hash_bits: usize, |
| |
| // Previous table, half the size of the current one |
| prev: Option<Box<Table<T>>>, |
| } |
| |
| struct TableEntry<T: Send> { |
| // Current owner of this entry, or 0 if this is an empty entry |
| owner: AtomicUsize, |
| |
| // The object associated with this entry. This is only ever accessed by the |
| // owner of the entry. |
| data: UnsafeCell<Option<Box<T>>>, |
| } |
| |
| // ThreadLocal is always Sync, even if T isn't |
| unsafe impl<T: Send> Sync for ThreadLocal<T> {} |
| |
| impl<T: Send> Default for ThreadLocal<T> { |
| fn default() -> ThreadLocal<T> { |
| ThreadLocal::new() |
| } |
| } |
| |
| impl<T: Send> Drop for ThreadLocal<T> { |
| fn drop(&mut self) { |
| unsafe { |
| Box::from_raw(self.table.load(Ordering::Relaxed)); |
| } |
| } |
| } |
| |
| // Implementation of Clone for TableEntry, needed to make vec![] work |
| impl<T: Send> Clone for TableEntry<T> { |
| fn clone(&self) -> TableEntry<T> { |
| TableEntry { |
| owner: AtomicUsize::new(0), |
| data: UnsafeCell::new(None), |
| } |
| } |
| } |
| |
| // Hash function for the thread id |
| #[cfg(target_pointer_width = "32")] |
| #[inline] |
| fn hash(id: usize, bits: usize) -> usize { |
| id.wrapping_mul(0x9E3779B9) >> (32 - bits) |
| } |
| #[cfg(target_pointer_width = "64")] |
| #[inline] |
| fn hash(id: usize, bits: usize) -> usize { |
| id.wrapping_mul(0x9E37_79B9_7F4A_7C15) >> (64 - bits) |
| } |
| |
| impl<T: Send> ThreadLocal<T> { |
| /// Creates a new empty `ThreadLocal`. |
| pub fn new() -> ThreadLocal<T> { |
| let entry = TableEntry { |
| owner: AtomicUsize::new(0), |
| data: UnsafeCell::new(None), |
| }; |
| let table = Table { |
| entries: vec![entry; 2].into_boxed_slice(), |
| hash_bits: 1, |
| prev: None, |
| }; |
| ThreadLocal { |
| table: AtomicPtr::new(Box::into_raw(Box::new(table))), |
| lock: Mutex::new(0), |
| } |
| } |
| |
| /// Returns the element for the current thread, if it exists. |
| pub fn get(&self) -> Option<&T> { |
| let id = thread_id::get(); |
| self.get_fast(id) |
| } |
| |
| /// Returns the element for the current thread, or creates it if it doesn't |
| /// exist. |
| pub fn get_or<F>(&self, create: F) -> &T |
| where |
| F: FnOnce() -> T, |
| { |
| unsafe { |
| self.get_or_try(|| Ok::<T, ()>(create())) |
| .unchecked_unwrap_ok() |
| } |
| } |
| |
| /// Returns the element for the current thread, or creates it if it doesn't |
| /// exist. If `create` fails, that error is returned and no element is |
| /// added. |
| pub fn get_or_try<F, E>(&self, create: F) -> Result<&T, E> |
| where |
| F: FnOnce() -> Result<T, E>, |
| { |
| let id = thread_id::get(); |
| match self.get_fast(id) { |
| Some(x) => Ok(x), |
| None => Ok(self.insert(id, Box::new(create()?), true)), |
| } |
| } |
| |
| // Simple hash table lookup function |
| fn lookup(id: usize, table: &Table<T>) -> Option<&UnsafeCell<Option<Box<T>>>> { |
| // Because we use a Mutex to prevent concurrent modifications (but not |
| // reads) of the hash table, we can avoid any memory barriers here. No |
| // elements between our hash bucket and our value can have been modified |
| // since we inserted our thread-local value into the table. |
| for entry in table.entries.iter().cycle().skip(hash(id, table.hash_bits)) { |
| let owner = entry.owner.load(Ordering::Relaxed); |
| if owner == id { |
| return Some(&entry.data); |
| } |
| if owner == 0 { |
| return None; |
| } |
| } |
| unreachable!(); |
| } |
| |
| // Fast path: try to find our thread in the top-level hash table |
| fn get_fast(&self, id: usize) -> Option<&T> { |
| let table = unsafe { &*self.table.load(Ordering::Acquire) }; |
| match Self::lookup(id, table) { |
| Some(x) => unsafe { Some((*x.get()).as_ref().unchecked_unwrap()) }, |
| None => self.get_slow(id, table), |
| } |
| } |
| |
| // Slow path: try to find our thread in the other hash tables, and then |
| // move it to the top-level hash table. |
| #[cold] |
| fn get_slow(&self, id: usize, table_top: &Table<T>) -> Option<&T> { |
| let mut current = &table_top.prev; |
| while let Some(ref table) = *current { |
| if let Some(x) = Self::lookup(id, table) { |
| let data = unsafe { (*x.get()).take().unchecked_unwrap() }; |
| return Some(self.insert(id, data, false)); |
| } |
| current = &table.prev; |
| } |
| None |
| } |
| |
| #[cold] |
| fn insert(&self, id: usize, data: Box<T>, new: bool) -> &T { |
| // Lock the Mutex to ensure only a single thread is modify the hash |
| // table at once. |
| let mut count = self.lock.lock().unwrap(); |
| if new { |
| *count += 1; |
| } |
| let table_raw = self.table.load(Ordering::Relaxed); |
| let table = unsafe { &*table_raw }; |
| |
| // If the current top-level hash table is more than 75% full, add a new |
| // level with 2x the capacity. Elements will be moved up to the new top |
| // level table as they are accessed. |
| let table = if *count > table.entries.len() * 3 / 4 { |
| let entry = TableEntry { |
| owner: AtomicUsize::new(0), |
| data: UnsafeCell::new(None), |
| }; |
| let new_table = Box::into_raw(Box::new(Table { |
| entries: vec![entry; table.entries.len() * 2].into_boxed_slice(), |
| hash_bits: table.hash_bits + 1, |
| prev: unsafe { Some(Box::from_raw(table_raw)) }, |
| })); |
| self.table.store(new_table, Ordering::Release); |
| unsafe { &*new_table } |
| } else { |
| table |
| }; |
| |
| // Insert the new element into the top-level hash table |
| for entry in table.entries.iter().cycle().skip(hash(id, table.hash_bits)) { |
| let owner = entry.owner.load(Ordering::Relaxed); |
| if owner == 0 { |
| unsafe { |
| entry.owner.store(id, Ordering::Relaxed); |
| *entry.data.get() = Some(data); |
| return (*entry.data.get()).as_ref().unchecked_unwrap(); |
| } |
| } |
| if owner == id { |
| // This can happen if create() inserted a value into this |
| // ThreadLocal between our calls to get_fast() and insert(). We |
| // just return the existing value and drop the newly-allocated |
| // Box. |
| unsafe { |
| return (*entry.data.get()).as_ref().unchecked_unwrap(); |
| } |
| } |
| } |
| unreachable!(); |
| } |
| |
| fn raw_iter(&mut self) -> RawIter<T> { |
| RawIter { |
| remaining: *self.lock.get_mut().unwrap(), |
| index: 0, |
| table: self.table.load(Ordering::Relaxed), |
| } |
| } |
| |
| /// Returns a mutable iterator over the local values of all threads. |
| /// |
| /// Since this call borrows the `ThreadLocal` mutably, this operation can |
| /// be done safely---the mutable borrow statically guarantees no other |
| /// threads are currently accessing their associated values. |
| pub fn iter_mut(&mut self) -> IterMut<T> { |
| IterMut { |
| raw: self.raw_iter(), |
| marker: PhantomData, |
| } |
| } |
| |
| /// Removes all thread-specific values from the `ThreadLocal`, effectively |
| /// reseting it to its original state. |
| /// |
| /// Since this call borrows the `ThreadLocal` mutably, this operation can |
| /// be done safely---the mutable borrow statically guarantees no other |
| /// threads are currently accessing their associated values. |
| pub fn clear(&mut self) { |
| *self = ThreadLocal::new(); |
| } |
| } |
| |
| impl<T: Send> IntoIterator for ThreadLocal<T> { |
| type Item = T; |
| type IntoIter = IntoIter<T>; |
| |
| fn into_iter(mut self) -> IntoIter<T> { |
| IntoIter { |
| raw: self.raw_iter(), |
| _thread_local: self, |
| } |
| } |
| } |
| |
| impl<'a, T: Send + 'a> IntoIterator for &'a mut ThreadLocal<T> { |
| type Item = &'a mut T; |
| type IntoIter = IterMut<'a, T>; |
| |
| fn into_iter(self) -> IterMut<'a, T> { |
| self.iter_mut() |
| } |
| } |
| |
| impl<T: Send + Default> ThreadLocal<T> { |
| /// Returns the element for the current thread, or creates a default one if |
| /// it doesn't exist. |
| pub fn get_or_default(&self) -> &T { |
| self.get_or(Default::default) |
| } |
| } |
| |
| impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> { |
| fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
| write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get()) |
| } |
| } |
| |
| impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {} |
| |
| struct RawIter<T: Send> { |
| remaining: usize, |
| index: usize, |
| table: *const Table<T>, |
| } |
| |
| impl<T: Send> Iterator for RawIter<T> { |
| type Item = *mut Option<Box<T>>; |
| |
| fn next(&mut self) -> Option<*mut Option<Box<T>>> { |
| if self.remaining == 0 { |
| return None; |
| } |
| |
| loop { |
| let entries = unsafe { &(*self.table).entries[..] }; |
| while self.index < entries.len() { |
| let val = entries[self.index].data.get(); |
| self.index += 1; |
| if unsafe { (*val).is_some() } { |
| self.remaining -= 1; |
| return Some(val); |
| } |
| } |
| self.index = 0; |
| self.table = unsafe { &**(*self.table).prev.as_ref().unchecked_unwrap() }; |
| } |
| } |
| |
| fn size_hint(&self) -> (usize, Option<usize>) { |
| (self.remaining, Some(self.remaining)) |
| } |
| } |
| |
| /// Mutable iterator over the contents of a `ThreadLocal`. |
| pub struct IterMut<'a, T: Send + 'a> { |
| raw: RawIter<T>, |
| marker: PhantomData<&'a mut ThreadLocal<T>>, |
| } |
| |
| impl<'a, T: Send + 'a> Iterator for IterMut<'a, T> { |
| type Item = &'a mut T; |
| |
| fn next(&mut self) -> Option<&'a mut T> { |
| self.raw |
| .next() |
| .map(|x| unsafe { &mut **(*x).as_mut().unchecked_unwrap() }) |
| } |
| |
| fn size_hint(&self) -> (usize, Option<usize>) { |
| self.raw.size_hint() |
| } |
| } |
| |
| impl<'a, T: Send + 'a> ExactSizeIterator for IterMut<'a, T> {} |
| |
| /// An iterator that moves out of a `ThreadLocal`. |
| pub struct IntoIter<T: Send> { |
| raw: RawIter<T>, |
| _thread_local: ThreadLocal<T>, |
| } |
| |
| impl<T: Send> Iterator for IntoIter<T> { |
| type Item = T; |
| |
| fn next(&mut self) -> Option<T> { |
| self.raw |
| .next() |
| .map(|x| unsafe { *(*x).take().unchecked_unwrap() }) |
| } |
| |
| fn size_hint(&self) -> (usize, Option<usize>) { |
| self.raw.size_hint() |
| } |
| } |
| |
| impl<T: Send> ExactSizeIterator for IntoIter<T> {} |
| |
| #[cfg(test)] |
| mod tests { |
| use super::{CachedThreadLocal, ThreadLocal}; |
| use std::cell::RefCell; |
| use std::sync::atomic::AtomicUsize; |
| use std::sync::atomic::Ordering::Relaxed; |
| use std::sync::Arc; |
| use std::thread; |
| |
| fn make_create() -> Arc<dyn Fn() -> usize + Send + Sync> { |
| let count = AtomicUsize::new(0); |
| Arc::new(move || count.fetch_add(1, Relaxed)) |
| } |
| |
| #[test] |
| fn same_thread() { |
| let create = make_create(); |
| let mut tls = ThreadLocal::new(); |
| assert_eq!(None, tls.get()); |
| assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls)); |
| assert_eq!(0, *tls.get_or(|| create())); |
| assert_eq!(Some(&0), tls.get()); |
| assert_eq!(0, *tls.get_or(|| create())); |
| assert_eq!(Some(&0), tls.get()); |
| assert_eq!(0, *tls.get_or(|| create())); |
| assert_eq!(Some(&0), tls.get()); |
| assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls)); |
| tls.clear(); |
| assert_eq!(None, tls.get()); |
| } |
| |
| #[test] |
| fn same_thread_cached() { |
| let create = make_create(); |
| let mut tls = CachedThreadLocal::new(); |
| assert_eq!(None, tls.get()); |
| assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls)); |
| assert_eq!(0, *tls.get_or(|| create())); |
| assert_eq!(Some(&0), tls.get()); |
| assert_eq!(0, *tls.get_or(|| create())); |
| assert_eq!(Some(&0), tls.get()); |
| assert_eq!(0, *tls.get_or(|| create())); |
| assert_eq!(Some(&0), tls.get()); |
| assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls)); |
| tls.clear(); |
| assert_eq!(None, tls.get()); |
| } |
| |
| #[test] |
| fn different_thread() { |
| let create = make_create(); |
| let tls = Arc::new(ThreadLocal::new()); |
| assert_eq!(None, tls.get()); |
| assert_eq!(0, *tls.get_or(|| create())); |
| assert_eq!(Some(&0), tls.get()); |
| |
| let tls2 = tls.clone(); |
| let create2 = create.clone(); |
| thread::spawn(move || { |
| assert_eq!(None, tls2.get()); |
| assert_eq!(1, *tls2.get_or(|| create2())); |
| assert_eq!(Some(&1), tls2.get()); |
| }) |
| .join() |
| .unwrap(); |
| |
| assert_eq!(Some(&0), tls.get()); |
| assert_eq!(0, *tls.get_or(|| create())); |
| } |
| |
| #[test] |
| fn different_thread_cached() { |
| let create = make_create(); |
| let tls = Arc::new(CachedThreadLocal::new()); |
| assert_eq!(None, tls.get()); |
| assert_eq!(0, *tls.get_or(|| create())); |
| assert_eq!(Some(&0), tls.get()); |
| |
| let tls2 = tls.clone(); |
| let create2 = create.clone(); |
| thread::spawn(move || { |
| assert_eq!(None, tls2.get()); |
| assert_eq!(1, *tls2.get_or(|| create2())); |
| assert_eq!(Some(&1), tls2.get()); |
| }) |
| .join() |
| .unwrap(); |
| |
| assert_eq!(Some(&0), tls.get()); |
| assert_eq!(0, *tls.get_or(|| create())); |
| } |
| |
| #[test] |
| fn iter() { |
| let tls = Arc::new(ThreadLocal::new()); |
| tls.get_or(|| Box::new(1)); |
| |
| let tls2 = tls.clone(); |
| thread::spawn(move || { |
| tls2.get_or(|| Box::new(2)); |
| let tls3 = tls2.clone(); |
| thread::spawn(move || { |
| tls3.get_or(|| Box::new(3)); |
| }) |
| .join() |
| .unwrap(); |
| }) |
| .join() |
| .unwrap(); |
| |
| let mut tls = Arc::try_unwrap(tls).unwrap(); |
| let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>(); |
| v.sort(); |
| assert_eq!(vec![1, 2, 3], v); |
| let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>(); |
| v.sort(); |
| assert_eq!(vec![1, 2, 3], v); |
| } |
| |
| #[test] |
| fn iter_cached() { |
| let tls = Arc::new(CachedThreadLocal::new()); |
| tls.get_or(|| Box::new(1)); |
| |
| let tls2 = tls.clone(); |
| thread::spawn(move || { |
| tls2.get_or(|| Box::new(2)); |
| let tls3 = tls2.clone(); |
| thread::spawn(move || { |
| tls3.get_or(|| Box::new(3)); |
| }) |
| .join() |
| .unwrap(); |
| }) |
| .join() |
| .unwrap(); |
| |
| let mut tls = Arc::try_unwrap(tls).unwrap(); |
| let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>(); |
| v.sort(); |
| assert_eq!(vec![1, 2, 3], v); |
| let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>(); |
| v.sort(); |
| assert_eq!(vec![1, 2, 3], v); |
| } |
| |
| #[test] |
| fn is_sync() { |
| fn foo<T: Sync>() {} |
| foo::<ThreadLocal<String>>(); |
| foo::<ThreadLocal<RefCell<String>>>(); |
| foo::<CachedThreadLocal<String>>(); |
| foo::<CachedThreadLocal<RefCell<String>>>(); |
| } |
| } |