// rustc_data_structures/sharded.rs — sharded lock-striped containers.

1use std::borrow::Borrow;
2use std::hash::{Hash, Hasher};
3use std::iter;
4
5use either::Either;
6use hashbrown::hash_table::{self, Entry, HashTable};
7
8use crate::fx::FxHasher;
9use crate::sync::{CacheAligned, Lock, LockGuard, Mode, is_dyn_thread_safe};
10
// 32 shards is sufficient to reduce contention on an 8-core Ryzen 7 1700,
// but this should be tested on higher core count CPUs. How the `Sharded` type gets used
// may also affect the ideal number of shards.
const SHARD_BITS: usize = 5;

// Number of shards; a power of two so `SHARDS - 1` works as an index mask.
const SHARDS: usize = 1 << SHARD_BITS;
17
/// An array of cache-line aligned inner locked structures with convenience methods.
/// A single field is used when the compiler uses only one thread.
pub enum Sharded<T> {
    /// Used when the compiler runs single-threaded: a single lock, no sharding.
    Single(Lock<T>),
    /// Used when the compiler may run multi-threaded: `SHARDS` cache-line
    /// aligned locks to reduce contention and false sharing.
    Shards(Box<[CacheAligned<Lock<T>>; SHARDS]>),
}
24
impl<T: Default> Default for Sharded<T> {
    #[inline]
    fn default() -> Self {
        // Each shard receives its own `T::default()` value.
        Self::new(T::default)
    }
}
31
impl<T> Sharded<T> {
    /// Creates a new `Sharded`, calling `value` once per shard.
    ///
    /// The sharding mode is fixed at construction: when `is_dyn_thread_safe()`
    /// is true, `SHARDS` locks are allocated; otherwise a single lock. The
    /// `lock_shard_by_*` methods below rely on the mode observed here not
    /// changing for the lifetime of this value.
    #[inline]
    pub fn new(mut value: impl FnMut() -> T) -> Self {
        if is_dyn_thread_safe() {
            return Sharded::Shards(Box::new(
                [(); SHARDS].map(|()| CacheAligned(Lock::new(value()))),
            ));
        }

        Sharded::Single(Lock::new(value()))
    }

    /// The shard is selected by hashing `val` with `FxHasher`.
    #[inline]
    pub fn get_shard_by_value<K: Hash + ?Sized>(&self, val: &K) -> &Lock<T> {
        match self {
            Self::Single(single) => single,
            // Only pay for hashing `val` when there is more than one shard.
            Self::Shards(..) => self.get_shard_by_hash(make_hash(val)),
        }
    }

    /// Returns the shard for a pre-computed hash (see `get_shard_hash` for
    /// which bits of the hash select the shard).
    #[inline]
    pub fn get_shard_by_hash(&self, hash: u64) -> &Lock<T> {
        self.get_shard_by_index(get_shard_hash(hash))
    }

    /// Returns the shard at index `i`, reduced modulo `SHARDS`.
    #[inline]
    pub fn get_shard_by_index(&self, i: usize) -> &Lock<T> {
        match self {
            Self::Single(single) => single,
            Self::Shards(shards) => {
                // SAFETY: The index gets ANDed with the shard mask, ensuring it is always inbounds.
                unsafe { &shards.get_unchecked(i & (SHARDS - 1)).0 }
            }
        }
    }

    /// The shard is selected by hashing `val` with `FxHasher`.
    #[inline]
    #[track_caller]
    pub fn lock_shard_by_value<K: Hash + ?Sized>(&self, val: &K) -> LockGuard<'_, T> {
        match self {
            Self::Single(single) => {
                // Synchronization is disabled so use the `lock_assume_no_sync` method optimized
                // for that case.

                // SAFETY: We know `is_dyn_thread_safe` was false when creating the lock thus
                // `might_be_dyn_thread_safe` was also false.
                unsafe { single.lock_assume(Mode::NoSync) }
            }
            Self::Shards(..) => self.lock_shard_by_hash(make_hash(val)),
        }
    }

    /// Locks the shard for a pre-computed hash and returns its guard.
    #[inline]
    #[track_caller]
    pub fn lock_shard_by_hash(&self, hash: u64) -> LockGuard<'_, T> {
        self.lock_shard_by_index(get_shard_hash(hash))
    }

    /// Locks the shard at index `i` (modulo `SHARDS`) and returns its guard.
    #[inline]
    #[track_caller]
    pub fn lock_shard_by_index(&self, i: usize) -> LockGuard<'_, T> {
        match self {
            Self::Single(single) => {
                // Synchronization is disabled so use the `lock_assume_no_sync` method optimized
                // for that case.

                // SAFETY: We know `is_dyn_thread_safe` was false when creating the lock thus
                // `might_be_dyn_thread_safe` was also false.
                unsafe { single.lock_assume(Mode::NoSync) }
            }
            Self::Shards(shards) => {
                // Synchronization is enabled so use the `lock_assume_sync` method optimized
                // for that case.

                // SAFETY (get_unchecked): The index gets ANDed with the shard mask, ensuring it is
                // always inbounds.
                // SAFETY (lock_assume_sync): We know `is_dyn_thread_safe` was true when creating
                // the lock thus `might_be_dyn_thread_safe` was also true.
                unsafe { shards.get_unchecked(i & (SHARDS - 1)).0.lock_assume(Mode::Sync) }
            }
        }
    }

    /// Returns an iterator that locks each shard as it is advanced and yields
    /// the guard.
    #[inline]
    pub fn lock_shards(&self) -> impl Iterator<Item = LockGuard<'_, T>> {
        match self {
            Self::Single(single) => Either::Left(iter::once(single.lock())),
            Self::Shards(shards) => Either::Right(shards.iter().map(|shard| shard.0.lock())),
        }
    }

    /// Like `lock_shards`, but uses `try_lock` and yields `None` for shards
    /// that could not be locked.
    #[inline]
    pub fn try_lock_shards(&self) -> impl Iterator<Item = Option<LockGuard<'_, T>>> {
        match self {
            Self::Single(single) => Either::Left(iter::once(single.try_lock())),
            Self::Shards(shards) => Either::Right(shards.iter().map(|shard| shard.0.try_lock())),
        }
    }
}
133
134#[inline]
135pub fn shards() -> usize {
136    if is_dyn_thread_safe() {
137        return SHARDS;
138    }
139
140    1
141}
142
143pub type ShardedHashMap<K, V> = Sharded<hash_table::HashTable<(K, V)>>;
144
145impl<K: Eq, V> ShardedHashMap<K, V> {
146    pub fn with_capacity(cap: usize) -> Self {
147        Self::new(|| HashTable::with_capacity(cap))
148    }
149    pub fn len(&self) -> usize {
150        self.lock_shards().map(|shard| shard.len()).sum()
151    }
152}
153
154impl<K: Eq + Hash, V> ShardedHashMap<K, V> {
155    #[inline]
156    pub fn get<Q>(&self, key: &Q) -> Option<V>
157    where
158        K: Borrow<Q>,
159        Q: Hash + Eq,
160        V: Clone,
161    {
162        let hash = make_hash(key);
163        let shard = self.lock_shard_by_hash(hash);
164        let (_, value) = shard.find(hash, |(k, _)| k.borrow() == key)?;
165        Some(value.clone())
166    }
167
168    #[inline]
169    pub fn get_or_insert_with(&self, key: K, default: impl FnOnce() -> V) -> V
170    where
171        V: Copy,
172    {
173        let hash = make_hash(&key);
174        let mut shard = self.lock_shard_by_hash(hash);
175
176        match table_entry(&mut shard, hash, &key) {
177            Entry::Occupied(e) => e.get().1,
178            Entry::Vacant(e) => {
179                let value = default();
180                e.insert((key, value));
181                value
182            }
183        }
184    }
185
186    /// Insert value into the [`ShardedHashMap`] with unique key.
187    ///
188    /// This function panics if debug_assertions are enabled and uniqueness is violated.
189    /// If uniqueness is violated but debug_assertions are disabled then lookups will arbitrarily
190    /// return one of the inserted elements.
191    #[inline]
192    pub fn insert_unique(&self, key: K, value: V) {
193        let hash = make_hash(&key);
194        let mut shard = self.lock_shard_by_hash(hash);
195
196        cfg_select! {
197            debug_assertions => match table_entry(&mut shard, hash, &key) {
198                Entry::Occupied(_) => {
199                    {
    ::core::panicking::panic_fmt(format_args!("tried to insert key that\'s already present"));
};panic!("tried to insert key that's already present");
200                }
201                Entry::Vacant(e) => {
202                    e.insert((key, value));
203                }
204            }
205            _ => {
206                shard.insert_unique(hash, (key, value), |(k, _)| make_hash(k));
207            }
208        }
209    }
210}
211
212impl<K: Eq + Hash + Copy> ShardedHashMap<K, ()> {
213    #[inline]
214    pub fn intern_ref<Q: ?Sized>(&self, value: &Q, make: impl FnOnce() -> K) -> K
215    where
216        K: Borrow<Q>,
217        Q: Hash + Eq,
218    {
219        let hash = make_hash(value);
220        let mut shard = self.lock_shard_by_hash(hash);
221
222        match table_entry(&mut shard, hash, value) {
223            Entry::Occupied(e) => e.get().0,
224            Entry::Vacant(e) => {
225                let v = make();
226                e.insert((v, ()));
227                v
228            }
229        }
230    }
231
232    #[inline]
233    pub fn intern<Q>(&self, value: Q, make: impl FnOnce(Q) -> K) -> K
234    where
235        K: Borrow<Q>,
236        Q: Hash + Eq,
237    {
238        let hash = make_hash(&value);
239        let mut shard = self.lock_shard_by_hash(hash);
240
241        match table_entry(&mut shard, hash, &value) {
242            Entry::Occupied(e) => e.get().0,
243            Entry::Vacant(e) => {
244                let v = make(value);
245                e.insert((v, ()));
246                v
247            }
248        }
249    }
250}
251
/// Types that can expose a stable address for identity comparisons
/// (see `contains_pointer_to`).
pub trait IntoPointer {
    /// Returns a pointer which outlives `self`.
    fn into_pointer(&self) -> *const ();
}
256
257impl<K: Eq + Hash + Copy + IntoPointer> ShardedHashMap<K, ()> {
258    pub fn contains_pointer_to<T: Hash + IntoPointer>(&self, value: &T) -> bool {
259        let hash = make_hash(&value);
260        let shard = self.lock_shard_by_hash(hash);
261        let value = value.into_pointer();
262        shard.find(hash, |(k, ())| k.into_pointer() == value).is_some()
263    }
264}
265
266#[inline]
267pub fn make_hash<K: Hash + ?Sized>(val: &K) -> u64 {
268    let mut state = FxHasher::default();
269    val.hash(&mut state);
270    state.finish()
271}
272
/// Returns the `HashTable` entry for `key`, assuming `hash` was produced by
/// [`make_hash`]. The second closure re-hashes existing entries; `HashTable`
/// uses it if the table must be resized while inserting into a vacant entry.
#[inline]
fn table_entry<'a, K, V, Q>(
    table: &'a mut HashTable<(K, V)>,
    hash: u64,
    key: &Q,
) -> Entry<'a, (K, V)>
where
    K: Hash + Borrow<Q>,
    Q: ?Sized + Eq,
{
    table.entry(hash, move |(k, _)| k.borrow() == key, |(k, _)| make_hash(k))
}
285
286/// Get a shard with a pre-computed hash value. If `get_shard_by_value` is
287/// ever used in combination with `get_shard_by_hash` on a single `Sharded`
288/// instance, then `hash` must be computed with `FxHasher`. Otherwise,
289/// `hash` can be computed with any hasher, so long as that hasher is used
290/// consistently for each `Sharded` instance.
291#[inline]
292fn get_shard_hash(hash: u64) -> usize {
293    let hash_len = size_of::<usize>();
294    // Ignore the top 7 bits as hashbrown uses these and get the next SHARD_BITS highest bits.
295    // hashbrown also uses the lowest bits, so we can't use those
296    (hash >> (hash_len * 8 - 7 - SHARD_BITS)) as usize
297}