1use std::borrow::Borrow;
2use std::hash::{Hash, Hasher};
3use std::iter;
4
5use either::Either;
6use hashbrown::hash_table::{self, Entry, HashTable};
7
8use crate::fx::FxHasher;
9use crate::sync::{CacheAligned, Lock, LockGuard, Mode, is_dyn_thread_safe};
10
/// Base-2 logarithm of the shard count: the number of hash bits needed to select a shard.
const SHARD_BITS: usize = 5;

/// Number of locks held by the `Sharded::Shards` variant (`2^SHARD_BITS`, i.e. 32).
const SHARDS: usize = 1usize << SHARD_BITS;
17
18pub enum Sharded<T> {
21 Single(Lock<T>),
22 Shards(Box<[CacheAligned<Lock<T>>; SHARDS]>),
23}
24
25impl<T: Default> Default for Sharded<T> {
26 #[inline]
27 fn default() -> Self {
28 Self::new(T::default)
29 }
30}
31
32impl<T> Sharded<T> {
33 #[inline]
34 pub fn new(mut value: impl FnMut() -> T) -> Self {
35 if is_dyn_thread_safe() {
36 return Sharded::Shards(Box::new(
37 [(); SHARDS].map(|()| CacheAligned(Lock::new(value()))),
38 ));
39 }
40
41 Sharded::Single(Lock::new(value()))
42 }
43
44 #[inline]
46 pub fn get_shard_by_value<K: Hash + ?Sized>(&self, val: &K) -> &Lock<T> {
47 match self {
48 Self::Single(single) => single,
49 Self::Shards(..) => self.get_shard_by_hash(make_hash(val)),
50 }
51 }
52
53 #[inline]
54 pub fn get_shard_by_hash(&self, hash: u64) -> &Lock<T> {
55 self.get_shard_by_index(get_shard_hash(hash))
56 }
57
58 #[inline]
59 pub fn get_shard_by_index(&self, i: usize) -> &Lock<T> {
60 match self {
61 Self::Single(single) => single,
62 Self::Shards(shards) => {
63 unsafe { &shards.get_unchecked(i & (SHARDS - 1)).0 }
65 }
66 }
67 }
68
69 #[inline]
71 #[track_caller]
72 pub fn lock_shard_by_value<K: Hash + ?Sized>(&self, val: &K) -> LockGuard<'_, T> {
73 match self {
74 Self::Single(single) => {
75 unsafe { single.lock_assume(Mode::NoSync) }
81 }
82 Self::Shards(..) => self.lock_shard_by_hash(make_hash(val)),
83 }
84 }
85
86 #[inline]
87 #[track_caller]
88 pub fn lock_shard_by_hash(&self, hash: u64) -> LockGuard<'_, T> {
89 self.lock_shard_by_index(get_shard_hash(hash))
90 }
91
92 #[inline]
93 #[track_caller]
94 pub fn lock_shard_by_index(&self, i: usize) -> LockGuard<'_, T> {
95 match self {
96 Self::Single(single) => {
97 unsafe { single.lock_assume(Mode::NoSync) }
103 }
104 Self::Shards(shards) => {
105 unsafe { shards.get_unchecked(i & (SHARDS - 1)).0.lock_assume(Mode::Sync) }
113 }
114 }
115 }
116
117 #[inline]
118 pub fn lock_shards(&self) -> impl Iterator<Item = LockGuard<'_, T>> {
119 match self {
120 Self::Single(single) => Either::Left(iter::once(single.lock())),
121 Self::Shards(shards) => Either::Right(shards.iter().map(|shard| shard.0.lock())),
122 }
123 }
124
125 #[inline]
126 pub fn try_lock_shards(&self) -> impl Iterator<Item = Option<LockGuard<'_, T>>> {
127 match self {
128 Self::Single(single) => Either::Left(iter::once(single.try_lock())),
129 Self::Shards(shards) => Either::Right(shards.iter().map(|shard| shard.0.try_lock())),
130 }
131 }
132}
133
134#[inline]
135pub fn shards() -> usize {
136 if is_dyn_thread_safe() {
137 return SHARDS;
138 }
139
140 1
141}
142
143pub type ShardedHashMap<K, V> = Sharded<hash_table::HashTable<(K, V)>>;
144
145impl<K: Eq, V> ShardedHashMap<K, V> {
146 pub fn with_capacity(cap: usize) -> Self {
147 Self::new(|| HashTable::with_capacity(cap))
148 }
149 pub fn len(&self) -> usize {
150 self.lock_shards().map(|shard| shard.len()).sum()
151 }
152}
153
154impl<K: Eq + Hash, V> ShardedHashMap<K, V> {
155 #[inline]
156 pub fn get<Q>(&self, key: &Q) -> Option<V>
157 where
158 K: Borrow<Q>,
159 Q: Hash + Eq,
160 V: Clone,
161 {
162 let hash = make_hash(key);
163 let shard = self.lock_shard_by_hash(hash);
164 let (_, value) = shard.find(hash, |(k, _)| k.borrow() == key)?;
165 Some(value.clone())
166 }
167
168 #[inline]
169 pub fn get_or_insert_with(&self, key: K, default: impl FnOnce() -> V) -> V
170 where
171 V: Copy,
172 {
173 let hash = make_hash(&key);
174 let mut shard = self.lock_shard_by_hash(hash);
175
176 match table_entry(&mut shard, hash, &key) {
177 Entry::Occupied(e) => e.get().1,
178 Entry::Vacant(e) => {
179 let value = default();
180 e.insert((key, value));
181 value
182 }
183 }
184 }
185
186 #[inline]
192 pub fn insert_unique(&self, key: K, value: V) {
193 let hash = make_hash(&key);
194 let mut shard = self.lock_shard_by_hash(hash);
195
196 cfg_select! {
197 debug_assertions => match table_entry(&mut shard, hash, &key) {
198 Entry::Occupied(_) => {
199 {
::core::panicking::panic_fmt(format_args!("tried to insert key that\'s already present"));
};panic!("tried to insert key that's already present");
200 }
201 Entry::Vacant(e) => {
202 e.insert((key, value));
203 }
204 }
205 _ => {
206 shard.insert_unique(hash, (key, value), |(k, _)| make_hash(k));
207 }
208 }
209 }
210}
211
212impl<K: Eq + Hash + Copy> ShardedHashMap<K, ()> {
213 #[inline]
214 pub fn intern_ref<Q: ?Sized>(&self, value: &Q, make: impl FnOnce() -> K) -> K
215 where
216 K: Borrow<Q>,
217 Q: Hash + Eq,
218 {
219 let hash = make_hash(value);
220 let mut shard = self.lock_shard_by_hash(hash);
221
222 match table_entry(&mut shard, hash, value) {
223 Entry::Occupied(e) => e.get().0,
224 Entry::Vacant(e) => {
225 let v = make();
226 e.insert((v, ()));
227 v
228 }
229 }
230 }
231
232 #[inline]
233 pub fn intern<Q>(&self, value: Q, make: impl FnOnce(Q) -> K) -> K
234 where
235 K: Borrow<Q>,
236 Q: Hash + Eq,
237 {
238 let hash = make_hash(&value);
239 let mut shard = self.lock_shard_by_hash(hash);
240
241 match table_entry(&mut shard, hash, &value) {
242 Entry::Occupied(e) => e.get().0,
243 Entry::Vacant(e) => {
244 let v = make(value);
245 e.insert((v, ()));
246 v
247 }
248 }
249 }
250}
251
/// Types that can present themselves as a type-erased raw pointer.
pub trait IntoPointer {
    /// Returns a `*const ()` derived from `&self`.
    ///
    /// NOTE(review): what the pointer addresses and how long it stays valid is decided by
    /// each implementor and is not visible here -- confirm before relying on it.
    fn into_pointer(&self) -> *const ();
}
256
257impl<K: Eq + Hash + Copy + IntoPointer> ShardedHashMap<K, ()> {
258 pub fn contains_pointer_to<T: Hash + IntoPointer>(&self, value: &T) -> bool {
259 let hash = make_hash(&value);
260 let shard = self.lock_shard_by_hash(hash);
261 let value = value.into_pointer();
262 shard.find(hash, |(k, ())| k.into_pointer() == value).is_some()
263 }
264}
265
266#[inline]
267pub fn make_hash<K: Hash + ?Sized>(val: &K) -> u64 {
268 let mut state = FxHasher::default();
269 val.hash(&mut state);
270 state.finish()
271}
272
273#[inline]
274fn table_entry<'a, K, V, Q>(
275 table: &'a mut HashTable<(K, V)>,
276 hash: u64,
277 key: &Q,
278) -> Entry<'a, (K, V)>
279where
280 K: Hash + Borrow<Q>,
281 Q: ?Sized + Eq,
282{
283 table.entry(hash, move |(k, _)| k.borrow() == key, |(k, _)| make_hash(k))
284}
285
286#[inline]
292fn get_shard_hash(hash: u64) -> usize {
293 let hash_len = size_of::<usize>();
294 (hash >> (hash_len * 8 - 7 - SHARD_BITS)) as usize
297}