diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index 1fae507d90161..4a3f3ac258f9e 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! A wrapper around `hashbrown::RawTable` that allows entries to be tracked by index +//! A wrapper around `hashbrown::HashTable` that allows entries to be tracked by index use crate::aggregates::group_values::HashValue; use crate::aggregates::topk::heap::Comparable; @@ -29,7 +29,7 @@ use arrow::datatypes::{DataType, i256}; use datafusion_common::Result; use datafusion_common::exec_datafusion_err; use half::f16; -use hashbrown::raw::RawTable; +use hashbrown::hash_table::HashTable; use std::fmt::Debug; use std::sync::Arc; @@ -48,13 +48,17 @@ pub struct HashTableItem { pub heap_idx: usize, } -/// A custom wrapper around `hashbrown::RawTable` that: +/// A custom wrapper around `hashbrown::HashTable` that: /// 1. limits the number of entries to the top K /// 2. Allocates a capacity greater than top K to maintain a low-fill factor and prevent resizing /// 3. Tracks indexes to allow corresponding heap to refer to entries by index vs hash -/// 4. Catches resize events to allow the corresponding heap to update it's indexes struct TopKHashTable { - map: RawTable>, + map: HashTable, + // Store the actual items separately to allow for index-based access + store: Vec>>, + // Free index in the store for reuse + free_index: Option, + // The maximum number of entries allowed limit: usize, } @@ -62,25 +66,10 @@ struct TopKHashTable { pub trait ArrowHashTable { fn set_batch(&mut self, ids: ArrayRef); fn len(&self) -> usize; - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: the caller must provide valid indexes - unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]); - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: the caller must provide a valid index - unsafe fn heap_idx_at(&self, map_idx: usize) -> usize; - unsafe fn take_all(&mut self, indexes: Vec) -> ArrayRef; - - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: the caller must provide valid indexes - unsafe fn find_or_insert( - &mut self, - row_idx: usize, - replace_idx: usize, - map: &mut Vec<(usize, usize)>, - ) -> (usize, bool); + fn update_heap_idx(&mut self, mapper: &[(usize, usize)]); + fn heap_idx_at(&self, map_idx: usize) -> usize; + fn take_all(&mut self, indexes: Vec) -> ArrayRef; + fn find_or_insert(&mut self, row_idx: usize, replace_idx: usize) -> (usize, bool); } // An implementation of ArrowHashTable for String keys @@ -130,91 +119,82 @@ impl ArrowHashTable for StringHashTable { self.map.len() } - unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { - unsafe { - self.map.update_heap_idx(mapper); - } + fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + self.map.update_heap_idx(mapper); } - unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { - unsafe { self.map.heap_idx_at(map_idx) } + fn heap_idx_at(&self, map_idx: usize) -> usize { + self.map.heap_idx_at(map_idx) } - unsafe fn take_all(&mut self, indexes: Vec) -> ArrayRef { - unsafe { - let ids = self.map.take_all(indexes); - match self.data_type { - DataType::Utf8 => Arc::new(StringArray::from(ids)), - DataType::LargeUtf8 => Arc::new(LargeStringArray::from(ids)), - DataType::Utf8View => Arc::new(StringViewArray::from(ids)), - _ => unreachable!(), - } + fn take_all(&mut self, indexes: Vec) -> ArrayRef { + let ids = self.map.take_all(indexes); + match self.data_type { + DataType::Utf8 => Arc::new(StringArray::from(ids)), + DataType::LargeUtf8 => Arc::new(LargeStringArray::from(ids)), + DataType::Utf8View => Arc::new(StringViewArray::from(ids)), + _ => unreachable!(), } } - unsafe fn find_or_insert( - &mut self, - row_idx: usize, - replace_idx: usize, - mapper: &mut Vec<(usize, usize)>, - ) -> (usize, bool) { - unsafe { - let id = match self.data_type { - DataType::Utf8 => { - let ids = self - .owned - .as_any() - .downcast_ref::() - .expect("Expected StringArray for DataType::Utf8"); - if ids.is_null(row_idx) { - None - } else { - Some(ids.value(row_idx)) - } + fn find_or_insert(&mut self, row_idx: usize, replace_idx: usize) -> (usize, bool) { + let id = match self.data_type { + DataType::Utf8 => { + let ids = self + .owned + .as_any() + .downcast_ref::() + .expect("Expected StringArray for DataType::Utf8"); + if ids.is_null(row_idx) { + None + } else { + Some(ids.value(row_idx)) } - DataType::LargeUtf8 => { - let ids = self - .owned - .as_any() - .downcast_ref::() - .expect("Expected LargeStringArray for DataType::LargeUtf8"); - if ids.is_null(row_idx) { - None - } else { - Some(ids.value(row_idx)) - } + } + DataType::LargeUtf8 => { + let ids = self + .owned + .as_any() + .downcast_ref::() + .expect("Expected LargeStringArray for DataType::LargeUtf8"); + if ids.is_null(row_idx) { + None + } else { + Some(ids.value(row_idx)) } - DataType::Utf8View => { - let ids = self - .owned - .as_any() - .downcast_ref::() - .expect("Expected StringViewArray for DataType::Utf8View"); - if ids.is_null(row_idx) { - None - } else { - Some(ids.value(row_idx)) - } + } + DataType::Utf8View => { + let ids = self + .owned + .as_any() + .downcast_ref::() + .expect("Expected StringViewArray for DataType::Utf8View"); + if ids.is_null(row_idx) { + None + } else { + Some(ids.value(row_idx)) } - _ => panic!("Unsupported data type"), - }; - - let hash = self.rnd.hash_one(id); - if let Some(map_idx) = self - .map - .find(hash, |mi| id == mi.as_ref().map(|id| id.as_str())) - { - return (map_idx, false); } + _ => panic!("Unsupported data type"), + }; - // we're full and this is a better value, so remove the worst - let heap_idx = self.map.remove_if_full(replace_idx); + // TODO: avoid double lookup by using entry API - // add the new group - let id = id.map(|id| id.to_string()); - let map_idx = self.map.insert(hash, id, heap_idx, mapper); - (map_idx, true) + let hash = self.rnd.hash_one(id); + if let Some(map_idx) = self + .map + .find(hash, |mi| id == mi.as_ref().map(|id| id.as_str())) + { + return (map_idx, false); } + + // we're full and this is a better value, so remove the worst + let heap_idx = self.map.remove_if_full(replace_idx); + + // add the new group + let id = id.map(|id| id.to_string()); + let map_idx = self.map.insert(hash, &id, heap_idx); + (map_idx, true) } } @@ -251,149 +231,137 @@ where self.map.len() } - unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { - unsafe { - self.map.update_heap_idx(mapper); - } + fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + self.map.update_heap_idx(mapper); } - unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { - unsafe { self.map.heap_idx_at(map_idx) } + fn heap_idx_at(&self, map_idx: usize) -> usize { + self.map.heap_idx_at(map_idx) } - unsafe fn take_all(&mut self, indexes: Vec) -> ArrayRef { - unsafe { - let ids = self.map.take_all(indexes); - let mut builder: PrimitiveBuilder = - PrimitiveArray::builder(ids.len()).with_data_type(self.kt.clone()); - for id in ids.into_iter() { - match id { - None => builder.append_null(), - Some(id) => builder.append_value(id), - } + fn take_all(&mut self, indexes: Vec) -> ArrayRef { + let ids = self.map.take_all(indexes); + let mut builder: PrimitiveBuilder = + PrimitiveArray::builder(ids.len()).with_data_type(self.kt.clone()); + for id in ids.into_iter() { + match id { + None => builder.append_null(), + Some(id) => builder.append_value(id), } - let ids = builder.finish(); - Arc::new(ids) } + let ids = builder.finish(); + Arc::new(ids) } - unsafe fn find_or_insert( - &mut self, - row_idx: usize, - replace_idx: usize, - mapper: &mut Vec<(usize, usize)>, - ) -> (usize, bool) { - unsafe { - let ids = self.owned.as_primitive::(); - let id: Option = if ids.is_null(row_idx) { - None - } else { - Some(ids.value(row_idx)) - }; - - let hash: u64 = id.hash(&self.rnd); - if let Some(map_idx) = self.map.find(hash, |mi| id == *mi) { - return (map_idx, false); - } - - // we're full and this is a better value, so remove the worst - let heap_idx = self.map.remove_if_full(replace_idx); + fn find_or_insert(&mut self, row_idx: usize, replace_idx: usize) -> (usize, bool) { + let ids = self.owned.as_primitive::(); + let id: Option = if ids.is_null(row_idx) { + None + } else { + Some(ids.value(row_idx)) + }; - // add the new group - let map_idx = self.map.insert(hash, id, heap_idx, mapper); - (map_idx, true) + let hash: u64 = id.hash(&self.rnd); + // TODO: avoid double lookup by using entry API + if let Some(map_idx) = self.map.find(hash, |mi| id == *mi) { + return (map_idx, false); } + + // we're full and this is a better value, so remove the worst + let heap_idx = self.map.remove_if_full(replace_idx); + + // add the new group + let map_idx = self.map.insert(hash, &id, heap_idx); + (map_idx, true) } } -impl TopKHashTable { +use hashbrown::hash_table::Entry; +impl TopKHashTable { pub fn new(limit: usize, capacity: usize) -> Self { Self { - map: RawTable::with_capacity(capacity), + map: HashTable::with_capacity(capacity), + store: Vec::with_capacity(capacity), + free_index: None, limit, } } pub fn find(&self, hash: u64, mut eq: impl FnMut(&ID) -> bool) -> Option { - let bucket = self.map.find(hash, |mi| eq(&mi.id))?; - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: getting the index of a bucket we just found - let idx = unsafe { self.map.bucket_index(&bucket) }; - Some(idx) + let eq = |&idx: &usize| eq(&self.store[idx].as_ref().unwrap().id); + self.map.find(hash, eq).copied() } - pub unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { - unsafe { - let bucket = self.map.bucket(map_idx); - bucket.as_ref().heap_idx - } + pub fn heap_idx_at(&self, map_idx: usize) -> usize { + self.store[map_idx].as_ref().unwrap().heap_idx } - pub unsafe fn remove_if_full(&mut self, replace_idx: usize) -> usize { - unsafe { - if self.map.len() >= self.limit { - self.map.erase(self.map.bucket(replace_idx)); - 0 // if full, always replace top node - } else { - self.map.len() // if we're not full, always append to end + pub fn remove_if_full(&mut self, replace_idx: usize) -> usize { + if self.map.len() >= self.limit { + let item_to_remove = self.store[replace_idx].as_ref().unwrap(); + let hash = item_to_remove.hash; + let id_to_remove = &item_to_remove.id; + + let eq = |&idx: &usize| self.store[idx].as_ref().unwrap().id == *id_to_remove; + let hasher = |idx: &usize| self.store[*idx].as_ref().unwrap().hash; + match self.map.entry(hash, eq, hasher) { + Entry::Occupied(entry) => { + let (removed_idx, _) = entry.remove(); + self.store[removed_idx] = None; + self.free_index = Some(removed_idx); + } + Entry::Vacant(_) => unreachable!(), } + 0 // if full, always replace top node + } else { + self.map.len() // if we're not full, always append to end } } - unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { - unsafe { - for (m, h) in mapper { - self.map.bucket(*m).as_mut().heap_idx = *h - } + fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + for (m, h) in mapper { + self.store[*m].as_mut().unwrap().heap_idx = *h; } } - pub fn insert( - &mut self, - hash: u64, - id: ID, - heap_idx: usize, - mapper: &mut Vec<(usize, usize)>, - ) -> usize { - let mi = HashTableItem::new(hash, id, heap_idx); - let bucket = self.map.try_insert_no_grow(hash, mi); - let bucket = match bucket { - Ok(bucket) => bucket, - Err(new_item) => { - let bucket = self.map.insert(hash, new_item, |mi| mi.hash); - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: we're getting indexes of buckets, not dereferencing them - unsafe { - for bucket in self.map.iter() { - let heap_idx = bucket.as_ref().heap_idx; - let map_idx = self.map.bucket_index(&bucket); - mapper.push((heap_idx, map_idx)); - } - } - bucket - } + pub fn insert(&mut self, hash: u64, id: &ID, heap_idx: usize) -> usize { + let mi = HashTableItem::new(hash, id.clone(), heap_idx); + let store_idx = if let Some(idx) = self.free_index.take() { + self.store[idx] = Some(mi); + idx + } else { + self.store.push(Some(mi)); + self.store.len() - 1 }; - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: we're getting indexes of buckets, not dereferencing them - unsafe { self.map.bucket_index(&bucket) } + + let hasher = |idx: &usize| self.store[*idx].as_ref().unwrap().hash; + if self.map.len() == self.map.capacity() { + self.map.reserve(self.limit, hasher); + } + + let eq_fn = |idx: &usize| self.store[*idx].as_ref().unwrap().id == *id; + match self.map.entry(hash, eq_fn, hasher) { + Entry::Occupied(_) => unreachable!("Item should not exist"), + Entry::Vacant(vacant) => { + vacant.insert(store_idx); + } + } + store_idx } pub fn len(&self) -> usize { self.map.len() } - pub unsafe fn take_all(&mut self, idxs: Vec) -> Vec { - unsafe { - let ids = idxs - .into_iter() - .map(|idx| self.map.bucket(idx).as_ref().id.clone()) - .collect(); - self.map.clear(); - ids - } + pub fn take_all(&mut self, idxs: Vec) -> Vec { + let ids = idxs + .into_iter() + .map(|idx| self.store[idx].take().unwrap().id) + .collect(); + self.map.clear(); + self.store.clear(); + self.free_index = None; + ids } } @@ -471,11 +439,8 @@ mod tests { let dt = DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())); let mut ht = new_hash_table(1, dt.clone())?; ht.set_batch(Arc::new(ids)); - let mut mapper = vec![]; - let ids = unsafe { - ht.find_or_insert(0, 0, &mut mapper); - ht.take_all(vec![0]) - }; + ht.find_or_insert(0, 0); + let ids = ht.take_all(vec![0]); assert_eq!(ids.data_type(), &dt); Ok(()) @@ -486,26 +451,13 @@ mod tests { let mut heap_to_map = BTreeMap::::new(); let mut map = TopKHashTable::>::new(5, 3); for (heap_idx, id) in vec!["1", "2", "3", "4", "5"].into_iter().enumerate() { - let mut mapper = vec![]; let hash = heap_idx as u64; - let map_idx = map.insert(hash, Some(id.to_string()), heap_idx, &mut mapper); + let map_idx = map.insert(hash, &Some(id.to_string()), heap_idx); let _ = heap_to_map.insert(heap_idx, map_idx); - if heap_idx == 3 { - assert_eq!( - mapper, - vec![(0, 0), (1, 1), (2, 2), (3, 3)], - "Pass {heap_idx} resized incorrectly!" - ); - for (heap_idx, map_idx) in mapper { - let _ = heap_to_map.insert(heap_idx, map_idx); - } - } else { - assert_eq!(mapper, vec![], "Pass {heap_idx} should not have resized!"); - } } let (_heap_idxs, map_idxs): (Vec<_>, Vec<_>) = heap_to_map.into_iter().unzip(); - let ids = unsafe { map.take_all(map_idxs) }; + let ids = map.take_all(map_idxs); assert_eq!( format!("{ids:?}"), r#"[Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]"# diff --git a/datafusion/physical-plan/src/aggregates/topk/heap.rs b/datafusion/physical-plan/src/aggregates/topk/heap.rs index abdf320ea39d8..b4569c3d0811d 100644 --- a/datafusion/physical-plan/src/aggregates/topk/heap.rs +++ b/datafusion/physical-plan/src/aggregates/topk/heap.rs @@ -72,7 +72,6 @@ pub trait ArrowHeap { fn set_batch(&mut self, vals: ArrayRef); fn is_worse(&self, idx: usize) -> bool; fn worst_map_idx(&self) -> usize; - fn renumber(&mut self, heap_to_map: &[(usize, usize)]); fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>); fn replace_if_better( &mut self, @@ -131,10 +130,6 @@ where self.heap.worst_map_idx() } - fn renumber(&mut self, heap_to_map: &[(usize, usize)]) { - self.heap.renumber(heap_to_map); - } - fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>) { let vals = self.batch.as_primitive::(); let new_val = vals.value(row_idx); @@ -268,14 +263,6 @@ impl TopKHeap { self.heapify_down(heap_idx, mapper); } - pub fn renumber(&mut self, heap_to_map: &[(usize, usize)]) { - for (heap_idx, map_idx) in heap_to_map.iter() { - if let Some(Some(hi)) = self.heap.get_mut(*heap_idx) { - hi.map_idx = *map_idx; - } - } - } - fn heapify_up(&mut self, mut idx: usize, mapper: &mut Vec<(usize, usize)>) { let desc = self.desc; while idx != 0 { @@ -608,29 +595,4 @@ mod tests { Ok(()) } - - #[test] - fn should_renumber() -> Result<()> { - let mut map = vec![]; - let mut heap = TopKHeap::new(10, false); - - heap.append_or_replace(1, 1, &mut map); - heap.append_or_replace(2, 2, &mut map); - - let actual = heap.to_string(); - assert_snapshot!(actual, @r" - val=2 idx=0, bucket=2 - └── val=1 idx=1, bucket=1 - "); - - let numbers = vec![(0, 1), (1, 2)]; - heap.renumber(numbers.as_slice()); - let actual = heap.to_string(); - assert_snapshot!(actual, @r" - val=2 idx=0, bucket=1 - └── val=1 idx=1, bucket=2 - "); - - Ok(()) - } } diff --git a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs index fdff6b3a1a51c..8e093d213e784 100644 --- a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs +++ b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs @@ -63,40 +63,26 @@ impl PriorityMap { // handle new groups we haven't seen yet map.clear(); let replace_idx = self.heap.worst_map_idx(); - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: replace_idx kept valid during resizes - let (map_idx, did_insert) = - unsafe { self.map.find_or_insert(row_idx, replace_idx, map) }; + + let (map_idx, did_insert) = self.map.find_or_insert(row_idx, replace_idx); if did_insert { - self.heap.renumber(map); - map.clear(); self.heap.insert(row_idx, map_idx, map); - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: the map was created on the line above, so all the indexes should be valid - unsafe { self.map.update_heap_idx(map) }; + self.map.update_heap_idx(map); return Ok(()); }; // this is a value for an existing group map.clear(); - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: map_idx was just found, so it is valid - let heap_idx = unsafe { self.map.heap_idx_at(map_idx) }; + let heap_idx = self.map.heap_idx_at(map_idx); self.heap.replace_if_better(heap_idx, row_idx, map); - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: the index map was just built, so it will be valid - unsafe { self.map.update_heap_idx(map) }; + self.map.update_heap_idx(map); Ok(()) } pub fn emit(&mut self) -> Result> { let (vals, map_idxs) = self.heap.drain(); - let ids = unsafe { self.map.take_all(map_idxs) }; + let ids = self.map.take_all(map_idxs); Ok(vec![ids, vals]) }