From 7c85ee6b90332b45ca1d8bee399a399b65e08c4f Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Mon, 12 May 2025 13:18:18 +0800 Subject: [PATCH 1/2] feat: support prefilter Signed-off-by: cutecutecat --- .github/workflows/check.yml | 16 +- crates/algorithm/src/fast_heap.rs | 3 + crates/algorithm/src/insert.rs | 36 +- crates/algorithm/src/lib.rs | 79 +++- src/index/algorithm.rs | 20 +- src/index/am/mod.rs | 56 ++- src/index/gucs.rs | 14 +- src/index/scanners/default.rs | 424 +++++++------------ src/index/scanners/maxsim.rs | 355 ++++++++-------- src/index/scanners/mod.rs | 4 + tests/{logic => general}/distance.slt | 0 tests/{logic => general}/external_build.slt | 3 + tests/{logic => general}/index.slt | 0 tests/{logic => general}/issue427.slt | 0 tests/{logic => general}/multivector.slt | 0 tests/{logic => general}/null.fail | 0 tests/{logic => general}/partition.slt | 3 + tests/{logic => general}/pin.slt | 0 tests/{logic => general}/pushdown_plan.slt | 0 tests/{logic => general}/pushdown_range.slt | 0 tests/{logic => general}/reindex.slt | 0 tests/general/rerank_in_index.slt | 104 +++++ tests/{logic => general}/rerank_in_table.slt | 7 +- tests/{logic => general}/vector.slt | 0 tests/pg16/filter_rerank_in_index.slt | 101 +++++ tests/pg16/filter_rerank_in_table.slt | 101 +++++ tests/pg17/filter_rerank_in_index.slt | 101 +++++ tests/pg17/filter_rerank_in_table.slt | 101 +++++ 28 files changed, 1067 insertions(+), 461 deletions(-) rename tests/{logic => general}/distance.slt (100%) rename tests/{logic => general}/external_build.slt (95%) rename tests/{logic => general}/index.slt (100%) rename tests/{logic => general}/issue427.slt (100%) rename tests/{logic => general}/multivector.slt (100%) rename tests/{logic => general}/null.fail (100%) rename tests/{logic => general}/partition.slt (95%) rename tests/{logic => general}/pin.slt (100%) rename tests/{logic => general}/pushdown_plan.slt (100%) rename tests/{logic => general}/pushdown_range.slt (100%) rename tests/{logic => general}/reindex.slt (100%) create mode 100644 tests/general/rerank_in_index.slt rename tests/{logic => general}/rerank_in_table.slt (94%) rename tests/{logic => general}/vector.slt (100%) create mode 100644 tests/pg16/filter_rerank_in_index.slt create mode 100644 tests/pg16/filter_rerank_in_table.slt create mode 100644 tests/pg17/filter_rerank_in_index.slt create mode 100644 tests/pg17/filter_rerank_in_table.slt diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 741de314..97083aa0 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -235,7 +235,21 @@ jobs: run: | sudo systemctl start postgresql psql -c 'CREATE EXTENSION IF NOT EXISTS vchord CASCADE;' - sqllogictest --db $USER --user $USER './tests/**/*.slt' + sqllogictest --db $USER --user $USER './tests/general/*.slt' + + - name: Sqllogictest(PostgrSQL 17 features) + if: matrix.version == '17' + run: | + sudo systemctl start postgresql + psql -c 'CREATE EXTENSION IF NOT EXISTS vchord CASCADE;' + sqllogictest --db $USER --user $USER './tests/pg17/*.slt' + + - name: Sqllogictest(PostgrSQL 16 features) + if: matrix.version == '16' + run: | + sudo systemctl start postgresql + psql -c 'CREATE EXTENSION IF NOT EXISTS vchord CASCADE;' + sqllogictest --db $USER --user $USER './tests/pg16/*.slt' - name: Package env: diff --git a/crates/algorithm/src/fast_heap.rs b/crates/algorithm/src/fast_heap.rs index 18ef21b2..fd7db993 100644 --- a/crates/algorithm/src/fast_heap.rs +++ b/crates/algorithm/src/fast_heap.rs @@ -76,6 +76,9 @@ impl From> for FastHeap { impl Sequence for FastHeap { type Item = T; type Inner = std::vec::IntoIter; + fn peek(&mut self) -> Option<&T> { + >::peek(self) + } fn next(&mut self) -> Option { self.pop() } diff --git a/crates/algorithm/src/insert.rs b/crates/algorithm/src/insert.rs index 4194eb54..b2e08e70 100644 --- a/crates/algorithm/src/insert.rs +++ b/crates/algorithm/src/insert.rs @@ -30,25 +30,49 @@ type Item<'b> = ( AlwaysEqual<&'b mut (u32, u16, &'b mut [u32])>, ); -pub fn insert<'r, 'b: 'r, R: RelationRead + RelationWrite, O: Operator>( +pub fn insert_vector( + index: &R, + payload: NonZero, + vector: &O::Vector, +) -> (Vec, u16) { + // `insert_vector` returns a tuple `(list, head)` which will be used in `insert_index` later: + // - `list`: Represents the list of elements to be inserted into the index. + // - `head`: Represents the head of the list, used as a starting point for insertion. + let meta_guard = index.read(0); + let meta_bytes = meta_guard.get(1).expect("data corruption"); + let meta_tuple = MetaTuple::deserialize_ref(meta_bytes); + let dims = meta_tuple.dims(); + let rerank_in_heap = meta_tuple.rerank_in_heap(); + assert_eq!(dims, vector.as_borrowed().dims(), "unmatched dimensions"); + let vectors_first = meta_tuple.vectors_first(); + drop(meta_guard); + + if !rerank_in_heap { + vectors::append::(index, vectors_first, vector.as_borrowed(), payload) + } else { + (Vec::new(), 0) + } +} + +pub fn insert_index<'r, 'b: 'r, R: RelationRead + RelationWrite, O: Operator>( index: &'r R, payload: NonZero, vector: O::Vector, bump: &'b impl Bump, mut prefetch_h1_vectors: impl PrefetcherHeapFamily<'r, R>, + list: Vec, + head: u16, ) { let meta_guard = index.read(0); let meta_bytes = meta_guard.get(1).expect("data corruption"); let meta_tuple = MetaTuple::deserialize_ref(meta_bytes); let dims = meta_tuple.dims(); let is_residual = meta_tuple.is_residual(); - let rerank_in_heap = meta_tuple.rerank_in_heap(); let height_of_root = meta_tuple.height_of_root(); assert_eq!(dims, vector.as_borrowed().dims(), "unmatched dimensions"); let root_prefetch = meta_tuple.root_prefetch().to_vec(); let root_head = meta_tuple.root_head(); let root_first = meta_tuple.root_first(); - let vectors_first = meta_tuple.vectors_first(); drop(meta_guard); let default_block_lut = if !is_residual { @@ -57,12 +81,6 @@ pub fn insert<'r, 'b: 'r, R: RelationRead + RelationWrite, O: Operator>( None }; - let (list, head) = if !rerank_in_heap { - vectors::append::(index, vectors_first, vector.as_borrowed(), payload) - } else { - (Vec::new(), 0) - }; - type State = (u32, Option<::Vector>); let mut state: State = { if is_residual { diff --git a/crates/algorithm/src/lib.rs b/crates/algorithm/src/lib.rs index cc37b57b..b13af61d 100644 --- a/crates/algorithm/src/lib.rs +++ b/crates/algorithm/src/lib.rs @@ -42,7 +42,7 @@ pub use bulkdelete::bulkdelete; pub use cache::cache; pub use cost::cost; pub use fast_heap::FastHeap; -pub use insert::insert; +pub use insert::{insert_index, insert_vector}; pub use maintain::maintain; pub use prefetcher::*; pub use prewarm::prewarm; @@ -219,17 +219,76 @@ impl<'b, T, A, B> Fetch for (T, AlwaysEqual<&'b mut (A, B, &'b mut [u32])>) { } } +pub struct Filter { + pub iter: S, + pub predicate: P, +} + +impl bool> Sequence for Filter { + type Item = S::Item; + type Inner = S::Inner; + + fn peek(&mut self) -> Option<&Self::Item> { + self.iter.peek() + } + + fn next(&mut self) -> Option { + loop { + let item = self.iter.peek()?; + if (self.predicate)(item) { + return self.iter.next(); + } else { + self.iter.next(); + continue; + } + } + } + + fn next_if(&mut self, predicate: impl FnOnce(&Self::Item) -> bool) -> Option { + loop { + let item = self.iter.peek()?; + if (self.predicate)(item) { + return match predicate(item) { + true => self.iter.next(), + false => None, + }; + } else { + self.iter.next(); + continue; + } + } + } + + fn into_inner(self) -> Self::Inner { + self.iter.into_inner() + } +} + pub trait Sequence { type Item; type Inner: Iterator; + fn peek(&mut self) -> Option<&Self::Item>; fn next(&mut self) -> Option; fn next_if(&mut self, predicate: impl FnOnce(&Self::Item) -> bool) -> Option; fn into_inner(self) -> Self::Inner; + fn filter

(self, predicate: P) -> Filter + where + Self: Sized, + P: FnMut(&Self::Item) -> bool, + { + Filter { + iter: self, + predicate, + } + } } impl Sequence for BinaryHeap { type Item = T; type Inner = std::vec::IntoIter; + fn peek(&mut self) -> Option<&T> { + >::peek(self) + } fn next(&mut self) -> Option { self.pop() } @@ -245,6 +304,9 @@ impl Sequence for BinaryHeap { impl Sequence for Peekable { type Item = I::Item; type Inner = Peekable; + fn peek(&mut self) -> Option<&I::Item> { + Peekable::peek(self) + } fn next(&mut self) -> Option { Iterator::next(self) } @@ -255,3 +317,18 @@ impl Sequence for Peekable { self } } + +pub fn seq_filter( + heap: impl Sequence, + prefilter: bool, + filter: F, +) -> impl Sequence +where + F: Fn(&T) -> bool, + T: Ord, +{ + heap.filter(move |t| match prefilter { + true => filter(t), + false => true, + }) +} diff --git a/src/index/algorithm.rs b/src/index/algorithm.rs index 95959877..ed207db0 100644 --- a/src/index/algorithm.rs +++ b/src/index/algorithm.rs @@ -322,42 +322,54 @@ pub fn insert( match (vector, opfamily.distance_kind()) { (OwnedVector::Vecf32(vector), DistanceKind::L2) => { assert!(opfamily.vector_kind() == VectorKind::Vecf32); - algorithm::insert::<_, Op, L2>>( + let (list, head) = insert_vector::<_, Op, L2>>(index, payload, &vector); + insert_index::<_, Op, L2>>( index, payload, RandomProject::project(vector.as_borrowed()), &bump, make_h1_plain_prefetcher, + list, + head, ) } (OwnedVector::Vecf32(vector), DistanceKind::Dot) => { assert!(opfamily.vector_kind() == VectorKind::Vecf32); - algorithm::insert::<_, Op, Dot>>( + let (list, head) = insert_vector::<_, Op, Dot>>(index, payload, &vector); + insert_index::<_, Op, Dot>>( index, payload, RandomProject::project(vector.as_borrowed()), &bump, make_h1_plain_prefetcher, + list, + head, ) } (OwnedVector::Vecf16(vector), DistanceKind::L2) => { assert!(opfamily.vector_kind() == VectorKind::Vecf16); - algorithm::insert::<_, Op, L2>>( + let (list, head) = insert_vector::<_, Op, L2>>(index, payload, &vector); + insert_index::<_, Op, L2>>( index, payload, RandomProject::project(vector.as_borrowed()), &bump, make_h1_plain_prefetcher, + list, + head, ) } (OwnedVector::Vecf16(vector), DistanceKind::Dot) => { assert!(opfamily.vector_kind() == VectorKind::Vecf16); - algorithm::insert::<_, Op, Dot>>( + let (list, head) = insert_vector::<_, Op, Dot>>(index, payload, &vector); + insert_index::<_, Op, Dot>>( index, payload, RandomProject::project(vector.as_borrowed()), &bump, make_h1_plain_prefetcher, + list, + head, ) } } diff --git a/src/index/am/mod.rs b/src/index/am/mod.rs index a6607f14..cc2d89a0 100644 --- a/src/index/am/mod.rs +++ b/src/index/am/mod.rs @@ -15,7 +15,7 @@ pub mod am_build; use super::algorithm::BumpAlloc; -use super::gucs::prererank_filtering; +use super::gucs::prefilter; use crate::index::gucs; use crate::index::lazy_cell::LazyCell; use crate::index::opclass::{Opfamily, opfamily}; @@ -571,6 +571,13 @@ impl Drop for HeapFetcher { } impl SearchFetcher for HeapFetcher { + /// Fetches data associated with the given `key` passes the filter criteria. + /// + /// # Parameters + /// - `key`: A 3-element array representing the key to be checked. + /// + /// Returns a tuple containing a reference to an array of `Datum` and a reference to an array of `bool`, + /// or `None` if no data is found. fn fetch(&mut self, key: [u16; 3]) -> Option<(&[Datum; 32], &[bool; 32])> { unsafe { let mut ctid = key_to_ctid(key); @@ -581,7 +588,7 @@ impl SearchFetcher for HeapFetcher { if !fetch_row_version(self.heap_relation, &mut ctid, self.snapshot, self.slot) { return None; } - if !self.hack.is_null() && prererank_filtering() { + if !self.hack.is_null() && prefilter() { if let Some(qual) = NonNull::new((*self.hack).ss.ps.qual) { use pgrx::datum::FromDatum; use pgrx::memcxt::PgMemoryContexts; @@ -615,6 +622,51 @@ impl SearchFetcher for HeapFetcher { Some((&self.values, &self.is_nulls)) } } + + /// Determines whether the given `key` passes the filter criteria. + /// + /// # Parameters + /// - `key`: A 3-element array representing the key to be checked. + /// + /// # Returns + /// - `true` if the key satisfies the filter conditions. + /// - `false` otherwise. + fn filter(&self, key: [u16; 3]) -> bool { + if self.hack.is_null() || !prefilter() { + return true; + } + unsafe { + let mut ctid = key_to_ctid(key); + let table_am = (*self.heap_relation).rd_tableam; + let fetch_row_version = (*table_am) + .tuple_fetch_row_version + .expect("unsupported heap access method"); + if !fetch_row_version(self.heap_relation, &mut ctid, self.snapshot, self.slot) { + return false; + } + if let Some(qual) = NonNull::new((*self.hack).ss.ps.qual) { + use pgrx::datum::FromDatum; + use pgrx::memcxt::PgMemoryContexts; + assert!(qual.as_ref().flags & pgrx::pg_sys::EEO_FLAG_IS_QUAL as u8 != 0); + let evalfunc = qual.as_ref().evalfunc.expect("no evalfunc for qual"); + if !(*self.hack).ss.ps.ps_ExprContext.is_null() { + let econtext = (*self.hack).ss.ps.ps_ExprContext; + (*econtext).ecxt_scantuple = self.slot; + pgrx::pg_sys::MemoryContextReset((*econtext).ecxt_per_tuple_memory); + let result = PgMemoryContexts::For((*econtext).ecxt_per_tuple_memory) + .switch_to(|_| { + let mut is_null = true; + let datum = evalfunc(qual.as_ptr(), econtext, &mut is_null); + bool::from_datum(datum, is_null) + }); + if result != Some(true) { + return false; + } + } + } + true + } + } } struct Index { diff --git a/src/index/gucs.rs b/src/index/gucs.rs index 7a42cee5..d91a04ee 100644 --- a/src/index/gucs.rs +++ b/src/index/gucs.rs @@ -36,7 +36,7 @@ static MAX_SCAN_TUPLES: GucSetting = GucSetting::::new(-1); static MAXSIM_REFINE: GucSetting = GucSetting::::new(0); static MAXSIM_THRESHOLD: GucSetting = GucSetting::::new(0); -static PRERERANK_FILTERING: GucSetting = GucSetting::::new(false); +static PREFILTER: GucSetting = GucSetting::::new(false); static IO_SEARCH: GucSetting = GucSetting::::new( #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15", feature = "pg16"))] @@ -110,10 +110,10 @@ pub fn init() { GucFlags::default(), ); GucRegistry::define_bool_guc( - "vchordrq.prererank_filtering", - "`prererank_filtering` argument of vchordrq.", - "`prererank_filtering` argument of vchordrq.", - &PRERERANK_FILTERING, + "vchordrq.prefilter", + "`prefilter` argument of vchordrq.", + "`prefilter` argument of vchordrq.", + &PREFILTER, GucContext::Userset, GucFlags::default(), ); @@ -207,8 +207,8 @@ pub fn prewarm_dim() -> Vec { } } -pub fn prererank_filtering() -> bool { - PRERERANK_FILTERING.get() +pub fn prefilter() -> bool { + PREFILTER.get() } pub fn io_search() -> Io { diff --git a/src/index/scanners/default.rs b/src/index/scanners/default.rs index 57bda84d..dbdd209b 100644 --- a/src/index/scanners/default.rs +++ b/src/index/scanners/default.rs @@ -15,11 +15,15 @@ use super::{Io, SearchBuilder, SearchFetcher, SearchOptions}; use crate::index::algorithm::*; use crate::index::am::pointer_to_kv; +use crate::index::gucs::prefilter; use crate::index::opclass::{Opfamily, Sphere}; -use algorithm::operator::{Dot, L2, Op}; +use algorithm::operator::{Dot, L2, Op, Operator}; use algorithm::types::{DistanceKind, OwnedVector, VectorKind}; use algorithm::*; +use always_equal::AlwaysEqual; +use distance::Distance; use half::f16; +use std::cmp::Reverse; use std::collections::BinaryHeap; use std::num::NonZero; use vector::VectorOwned; @@ -67,7 +71,7 @@ impl SearchBuilder for DefaultBuilder { self, index: &'a R, options: SearchOptions, - mut fetcher: impl SearchFetcher + 'a, + fetcher: impl SearchFetcher + 'a, bump: &'a impl Bump, ) -> Box + 'a> where @@ -104,14 +108,12 @@ impl SearchBuilder for DefaultBuilder { let iter: Box)>> = match (opfamily.vector_kind(), opfamily.distance_kind()) { (VectorKind::Vecf32, DistanceKind::L2) => { - let vector = RandomProject::project( - if let OwnedVector::Vecf32(vector) = vector { - vector - } else { - unreachable!() - } - .as_borrowed(), - ); + let original_vector = if let OwnedVector::Vecf32(vector) = vector { + vector + } else { + unreachable!() + }; + let vector = RandomProject::project(original_vector.as_borrowed()); let results = match options.io_search { Io::Plain => default_search::<_, Op, L2>>( index, @@ -141,78 +143,32 @@ impl SearchBuilder for DefaultBuilder { make_h0_stream_prefetcher, ), }; - let fetch = move |payload| { - let (key, _) = pointer_to_kv(payload); - let (datums, is_nulls) = fetcher.fetch(key)?; - let datum = (!is_nulls[0]).then_some(datums[0]); - let maybe_vector = unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; - let raw = if let OwnedVector::Vecf32(vector) = maybe_vector.unwrap() { - vector + let method = how(index); + let vector_insider = |vector: OwnedVector| { + if let OwnedVector::Vecf32(v) = vector { + v } else { unreachable!() - }; - Some(RandomProject::project(raw.as_borrowed())) - }; - let method = how(index); - match (method, options.io_rerank) { - (RerankMethod::Index, Io::Plain) => { - let prefetcher = - PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - Box::new( - rerank_index::, L2>, _, _>(vector, prefetcher) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) } - (RerankMethod::Index, Io::Simple) => { - let prefetcher = SimplePrefetcher::<'a, R, BinaryHeap<_>>::new( - index, - results.into(), - ); - Box::new( - rerank_index::, L2>, _, _>(vector, prefetcher) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) - } - (RerankMethod::Index, Io::Stream) => { - let prefetcher = StreamPrefetcher::<_, BinaryHeap<_>>::new( - index, - results.into(), - Hints::default(), - ); - Box::new( - rerank_index::, L2>, _, _>(vector, prefetcher) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) - } - (RerankMethod::Heap, _) => { - let prefetcher = - PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - Box::new( - rerank_heap::, L2>, _, _>( - vector, prefetcher, fetch, - ) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) - } - } + }; + rerank_wrapper::, L2>, _, _>( + original_vector, + index, + opfamily, + fetcher, + results, + method, + options.io_rerank, + vector_insider, + ) } (VectorKind::Vecf32, DistanceKind::Dot) => { - let vector = RandomProject::project( - if let OwnedVector::Vecf32(vector) = vector { - vector - } else { - unreachable!() - } - .as_borrowed(), - ); + let original_vector = if let OwnedVector::Vecf32(vector) = vector { + vector + } else { + unreachable!() + }; + let vector = RandomProject::project(original_vector.as_borrowed()); let results = match options.io_search { Io::Plain => default_search::<_, Op, Dot>>( index, @@ -242,76 +198,32 @@ impl SearchBuilder for DefaultBuilder { make_h0_stream_prefetcher, ), }; - let fetch = move |payload| { - let (key, _) = pointer_to_kv(payload); - let (datums, is_nulls) = fetcher.fetch(key)?; - let datum = (!is_nulls[0]).then_some(datums[0]); - let maybe_vector = unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; - let raw = if let OwnedVector::Vecf32(vector) = maybe_vector.unwrap() { - vector + let method = how(index); + let vector_insider = |vector: OwnedVector| { + if let OwnedVector::Vecf32(v) = vector { + v } else { unreachable!() - }; - Some(RandomProject::project(raw.as_borrowed())) - }; - let method = how(index); - match (method, options.io_rerank) { - (RerankMethod::Index, Io::Plain) => { - let prefetcher = - PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - Box::new( - rerank_index::, Dot>, _, _>(vector, prefetcher) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) - } - (RerankMethod::Index, Io::Simple) => { - let prefetcher = - SimplePrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - Box::new( - rerank_index::, Dot>, _, _>(vector, prefetcher) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) - } - (RerankMethod::Index, Io::Stream) => { - let prefetcher = StreamPrefetcher::<_, BinaryHeap<_>>::new( - index, - results.into(), - Hints::default(), - ); - Box::new( - rerank_index::, Dot>, _, _>(vector, prefetcher) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) - } - (RerankMethod::Heap, _) => { - let prefetcher = - PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - Box::new( - rerank_heap::, Dot>, _, _>( - vector, prefetcher, fetch, - ) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) } - } + }; + rerank_wrapper::, Dot>, _, _>( + original_vector, + index, + opfamily, + fetcher, + results, + method, + options.io_rerank, + vector_insider, + ) } (VectorKind::Vecf16, DistanceKind::L2) => { - let vector = RandomProject::project( - if let OwnedVector::Vecf16(vector) = vector { - vector - } else { - unreachable!() - } - .as_borrowed(), - ); + let original_vector = if let OwnedVector::Vecf16(vector) = vector { + vector + } else { + unreachable!() + }; + let vector = RandomProject::project(original_vector.as_borrowed()); let results = match options.io_search { Io::Plain => default_search::<_, Op, L2>>( index, @@ -341,76 +253,32 @@ impl SearchBuilder for DefaultBuilder { make_h0_stream_prefetcher, ), }; - let fetch = move |payload| { - let (key, _) = pointer_to_kv(payload); - let (datums, is_nulls) = fetcher.fetch(key)?; - let datum = (!is_nulls[0]).then_some(datums[0]); - let maybe_vector = unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; - let raw = if let OwnedVector::Vecf16(vector) = maybe_vector.unwrap() { - vector + let method = how(index); + let vector_insider = |vector: OwnedVector| { + if let OwnedVector::Vecf16(v) = vector { + v } else { unreachable!() - }; - Some(RandomProject::project(raw.as_borrowed())) - }; - let method = how(index); - match (method, options.io_rerank) { - (RerankMethod::Index, Io::Plain) => { - let prefetcher = - PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - Box::new( - rerank_index::, L2>, _, _>(vector, prefetcher) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) - } - (RerankMethod::Index, Io::Simple) => { - let prefetcher = - SimplePrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - Box::new( - rerank_index::, L2>, _, _>(vector, prefetcher) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) - } - (RerankMethod::Index, Io::Stream) => { - let prefetcher = StreamPrefetcher::<_, BinaryHeap<_>>::new( - index, - results.into(), - Hints::default(), - ); - Box::new( - rerank_index::, L2>, _, _>(vector, prefetcher) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) } - (RerankMethod::Heap, _) => { - let prefetcher = - PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - Box::new( - rerank_heap::, L2>, _, _>( - vector, prefetcher, fetch, - ) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) - } - } + }; + rerank_wrapper::, L2>, _, _>( + original_vector, + index, + opfamily, + fetcher, + results, + method, + options.io_rerank, + vector_insider, + ) } (VectorKind::Vecf16, DistanceKind::Dot) => { - let vector = RandomProject::project( - if let OwnedVector::Vecf16(vector) = vector { - vector - } else { - unreachable!() - } - .as_borrowed(), - ); + let original_vector = if let OwnedVector::Vecf16(vector) = vector { + vector + } else { + unreachable!() + }; + let vector = RandomProject::project(original_vector.as_borrowed()); let results = match options.io_search { Io::Plain => default_search::<_, Op, Dot>>( index, @@ -440,66 +308,24 @@ impl SearchBuilder for DefaultBuilder { make_h0_stream_prefetcher, ), }; - let fetch = move |payload| { - let (key, _) = pointer_to_kv(payload); - let (datums, is_nulls) = fetcher.fetch(key)?; - let datum = (!is_nulls[0]).then_some(datums[0]); - let maybe_vector = unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; - let raw = if let OwnedVector::Vecf16(vector) = maybe_vector.unwrap() { - vector + let method = how(index); + let vector_insider = |vector: OwnedVector| { + if let OwnedVector::Vecf16(v) = vector { + v } else { unreachable!() - }; - Some(RandomProject::project(raw.as_borrowed())) - }; - let method = how(index); - match (method, options.io_rerank) { - (RerankMethod::Index, Io::Plain) => { - let prefetcher = - PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - Box::new( - rerank_index::, Dot>, _, _>(vector, prefetcher) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) - } - (RerankMethod::Index, Io::Simple) => { - let prefetcher = - SimplePrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - Box::new( - rerank_index::, Dot>, _, _>(vector, prefetcher) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) } - (RerankMethod::Index, Io::Stream) => { - let prefetcher = StreamPrefetcher::<_, BinaryHeap<_>>::new( - index, - results.into(), - Hints::default(), - ); - Box::new( - rerank_index::, Dot>, _, _>(vector, prefetcher) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) - } - (RerankMethod::Heap, _) => { - let prefetcher = - PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - Box::new( - rerank_heap::, Dot>, _, _>( - vector, prefetcher, fetch, - ) - .map(move |(distance, payload)| { - (opfamily.output(distance), payload) - }), - ) - } - } + }; + rerank_wrapper::, Dot>, _, _>( + original_vector, + index, + opfamily, + fetcher, + results, + method, + options.io_rerank, + vector_insider, + ) } }; let iter = if let Some(threshold) = threshold { @@ -518,3 +344,73 @@ impl SearchBuilder for DefaultBuilder { })) } } + +type SeqElement<'a> = ( + (Reverse, AlwaysEqual<()>), + AlwaysEqual<&'a mut (NonZero, u16, &'a mut [u32])>, +); + +#[allow(clippy::too_many_arguments)] +fn rerank_wrapper<'a, O: Operator>, R, T>( + vector: O::Vector, + index: &'a R, + opfamily: Opfamily, + mut fetcher: impl SearchFetcher + 'a, + results: Vec>, + method: RerankMethod, + io_rerank: Io, + vector_insider: impl Fn(OwnedVector) -> VectOwned + 'a, +) -> Box)> + 'a> +where + R: RelationRead + RelationPrefetch + RelationReadStream, +{ + match (method, io_rerank) { + (RerankMethod::Index, Io::Plain) => { + let seq = seq_filter(BinaryHeap::from(results), prefilter(), move |key| { + let (key, _) = pointer_to_kv(key.1.0.0); + fetcher.filter(key) + }); + let prefetcher = PlainPrefetcher::<_, _>::new(index, seq); + Box::new( + rerank_index::(vector, prefetcher) + .map(move |(distance, payload)| (opfamily.output(distance), payload)), + ) + } + (RerankMethod::Index, Io::Simple) => { + let seq = seq_filter(BinaryHeap::from(results), prefilter(), move |key| { + let (key, _) = pointer_to_kv(key.1.0.0); + fetcher.filter(key) + }); + let prefetcher = SimplePrefetcher::<'a, R, _>::new(index, seq); + Box::new( + rerank_index::(vector, prefetcher) + .map(move |(distance, payload)| (opfamily.output(distance), payload)), + ) + } + (RerankMethod::Index, Io::Stream) => { + let seq = seq_filter(BinaryHeap::from(results), prefilter(), move |key| { + let (key, _) = pointer_to_kv(key.1.0.0); + fetcher.filter(key) + }); + let prefetcher = StreamPrefetcher::<_, _>::new(index, seq, Hints::default()); + Box::new( + rerank_index::(vector, prefetcher) + .map(move |(distance, payload)| (opfamily.output(distance), payload)), + ) + } + (RerankMethod::Heap, _) => { + let fetch = move |payload| { + let (key, _) = pointer_to_kv(payload); + let (datums, is_nulls) = fetcher.fetch(key)?; + let datum = (!is_nulls[0]).then_some(datums[0]); + let maybe_vector = unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; + Some(vector_insider(maybe_vector.unwrap())) + }; + let prefetcher = PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); + Box::new( + rerank_heap::(vector, prefetcher, fetch) + .map(move |(distance, payload)| (opfamily.output(distance), payload)), + ) + } + } +} diff --git a/src/index/scanners/maxsim.rs b/src/index/scanners/maxsim.rs index 5bf8a63d..febd5e1a 100644 --- a/src/index/scanners/maxsim.rs +++ b/src/index/scanners/maxsim.rs @@ -15,6 +15,7 @@ use super::{SearchBuilder, SearchFetcher, SearchOptions}; use crate::index::algorithm::{RandomProject, *}; use crate::index::am::pointer_to_kv; +use crate::index::gucs::prefilter; use crate::index::opclass::Opfamily; use crate::index::scanners::Io; use algorithm::operator::Dot; @@ -60,7 +61,7 @@ impl SearchBuilder for MaxsimBuilder { self, index: &'a R, options: SearchOptions, - _fetcher: impl SearchFetcher + 'a, + fetcher: impl SearchFetcher + 'a, bump: &'a impl Bump, ) -> Box + 'a> where @@ -104,191 +105,203 @@ impl SearchBuilder for MaxsimBuilder { let iter: Box> = match opfamily.vector_kind() { VectorKind::Vecf32 => { type Op = operator::Op, Dot>; - let vectors = vectors + let original_vectors = vectors .into_iter() .map(|vector| { - RandomProject::project( - if let OwnedVector::Vecf32(vector) = vector { - vector - } else { - unreachable!() - } - .as_borrowed(), - ) + if let OwnedVector::Vecf32(vector) = vector { + vector + } else { + unreachable!() + } }) .collect::>(); - Box::new(vectors.into_iter().map(|vector| { - let (results, estimation_by_threshold) = match options.io_search { - Io::Plain => maxsim_search::<_, Op>( - index, - vector.clone(), - options.probes.clone(), - options.epsilon, - maxsim_threshold, - bump, - make_h1_plain_prefetcher.clone(), - make_h0_plain_prefetcher.clone(), - ), - Io::Simple => maxsim_search::<_, Op>( - index, - vector.clone(), - options.probes.clone(), - options.epsilon, - maxsim_threshold, - bump, - make_h1_plain_prefetcher.clone(), - make_h0_simple_prefetcher.clone(), - ), - Io::Stream => maxsim_search::<_, Op>( - index, - vector.clone(), - options.probes.clone(), - options.epsilon, - maxsim_threshold, - bump, - make_h1_plain_prefetcher.clone(), - make_h0_stream_prefetcher.clone(), - ), - }; - let (mut accu_set, mut rough_set) = (Vec::new(), Vec::new()); - if maxsim_refine != 0 && !results.is_empty() { - match options.io_rerank { - Io::Plain => { - let prefetcher = - PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - let mut reranker = - rerank_index::(vector.clone(), prefetcher); - accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); - let (rough_iter, accu_iter) = reranker.finish(); - accu_set.extend(accu_iter.map(accu_map)); - rough_set.extend(rough_iter.into_iter().map(rough_map)); - } - Io::Simple => { - let prefetcher = SimplePrefetcher::<'a, R, BinaryHeap<_>>::new( - index, - results.into(), - ); - let mut reranker = - rerank_index::(vector.clone(), prefetcher); - accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); - let (rough_iter, accu_iter) = reranker.finish(); - accu_set.extend(accu_iter.map(accu_map)); - rough_set.extend(rough_iter.into_iter().map(rough_map)); - } - Io::Stream => { - let prefetcher = StreamPrefetcher::<_, BinaryHeap<_>>::new( - index, - results.into(), - Hints::default(), - ); - let mut reranker = - rerank_index::(vector.clone(), prefetcher); - accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); - let (rough_iter, accu_iter) = reranker.finish(); - accu_set.extend(accu_iter.map(accu_map)); - rough_set.extend(rough_iter.into_iter().map(rough_map)); + let vectors = original_vectors + .clone() + .into_iter() + .map(|vector| RandomProject::project(vector.as_borrowed())); + Box::new(vectors.into_iter().zip(original_vectors).map( + |(vector, original_vector)| { + let (results, estimation_by_threshold) = match options.io_search { + Io::Plain => maxsim_search::<_, Op>( + index, + vector.clone(), + options.probes.clone(), + options.epsilon, + maxsim_threshold, + bump, + make_h1_plain_prefetcher.clone(), + make_h0_plain_prefetcher.clone(), + ), + Io::Simple => maxsim_search::<_, Op>( + index, + vector.clone(), + options.probes.clone(), + options.epsilon, + maxsim_threshold, + bump, + make_h1_plain_prefetcher.clone(), + make_h0_simple_prefetcher.clone(), + ), + Io::Stream => maxsim_search::<_, Op>( + index, + vector.clone(), + options.probes.clone(), + options.epsilon, + maxsim_threshold, + bump, + make_h1_plain_prefetcher.clone(), + make_h0_stream_prefetcher.clone(), + ), + }; + let (mut accu_set, mut rough_set) = (Vec::new(), Vec::new()); + if maxsim_refine != 0 && !results.is_empty() { + let seq = seq_filter(BinaryHeap::from(results), prefilter(), |key| { + let (key, _) = pointer_to_kv(key.1.0.0); + fetcher.filter(key) + }); + match options.io_rerank { + Io::Plain => { + let prefetcher = PlainPrefetcher::<_, _>::new(index, seq); + let mut reranker = rerank_index::( + original_vector.clone(), + prefetcher, + ); + accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); + let (rough_iter, accu_iter) = reranker.finish(); + accu_set.extend(accu_iter.map(accu_map)); + rough_set.extend(rough_iter.into_iter().map(rough_map)); + } + Io::Simple => { + let prefetcher = SimplePrefetcher::<'a, R, _>::new(index, seq); + let mut reranker = rerank_index::( + original_vector.clone(), + prefetcher, + ); + accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); + let (rough_iter, accu_iter) = reranker.finish(); + accu_set.extend(accu_iter.map(accu_map)); + rough_set.extend(rough_iter.into_iter().map(rough_map)); + } + Io::Stream => { + let prefetcher = + StreamPrefetcher::<_, _>::new(index, seq, Hints::default()); + let mut reranker = rerank_index::( + original_vector.clone(), + prefetcher, + ); + accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); + let (rough_iter, accu_iter) = reranker.finish(); + accu_set.extend(accu_iter.map(accu_map)); + rough_set.extend(rough_iter.into_iter().map(rough_map)); + } } + } else { + let rough_iter = results.into_iter(); + rough_set.extend(rough_iter.map(rough_map)); } - } else { - let rough_iter = results.into_iter(); - rough_set.extend(rough_iter.map(rough_map)); - } - (accu_set, rough_set, estimation_by_threshold) - })) + (accu_set, rough_set, estimation_by_threshold) + }, + )) } VectorKind::Vecf16 => { type Op = operator::Op, Dot>; - let vectors = vectors + let original_vectors = vectors .into_iter() .map(|vector| { - RandomProject::project( - if let OwnedVector::Vecf16(vector) = vector { - vector - } else { - unreachable!() - } - .as_borrowed(), - ) + if let OwnedVector::Vecf16(vector) = vector { + vector + } else { + unreachable!() + } }) .collect::>(); - Box::new(vectors.into_iter().map(|vector| { - let (results, estimation_by_threshold) = match options.io_search { - Io::Plain => maxsim_search::<_, Op>( - index, - vector.clone(), - options.probes.clone(), - options.epsilon, - maxsim_threshold, - bump, - make_h1_plain_prefetcher.clone(), - make_h0_plain_prefetcher.clone(), - ), - Io::Simple => maxsim_search::<_, Op>( - index, - vector.clone(), - options.probes.clone(), - options.epsilon, - maxsim_threshold, - bump, - make_h1_plain_prefetcher.clone(), - make_h0_simple_prefetcher.clone(), - ), - Io::Stream => maxsim_search::<_, Op>( - index, - vector.clone(), - options.probes.clone(), - options.epsilon, - maxsim_threshold, - bump, - make_h1_plain_prefetcher.clone(), - make_h0_stream_prefetcher.clone(), - ), - }; - let (mut accu_set, mut rough_set) = (Vec::new(), Vec::new()); - if maxsim_refine != 0 && !results.is_empty() { - match options.io_rerank { - Io::Plain => { - let prefetcher = - PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - let mut reranker = - rerank_index::(vector.clone(), prefetcher); - accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); - let (rough_iter, accu_iter) = reranker.finish(); - accu_set.extend(accu_iter.map(accu_map)); - rough_set.extend(rough_iter.into_iter().map(rough_map)); - } - Io::Simple => { - let prefetcher = SimplePrefetcher::<'a, R, BinaryHeap<_>>::new( - index, - results.into(), - ); - let mut reranker = - rerank_index::(vector.clone(), prefetcher); - accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); - let (rough_iter, accu_iter) = reranker.finish(); - accu_set.extend(accu_iter.map(accu_map)); - rough_set.extend(rough_iter.into_iter().map(rough_map)); - } - Io::Stream => { - let prefetcher = StreamPrefetcher::<_, BinaryHeap<_>>::new( - index, - results.into(), - Hints::default(), - ); - let mut reranker = - rerank_index::(vector.clone(), prefetcher); - accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); - let (rough_iter, accu_iter) = reranker.finish(); - accu_set.extend(accu_iter.map(accu_map)); - rough_set.extend(rough_iter.into_iter().map(rough_map)); + let vectors = original_vectors + .clone() + .into_iter() + .map(|vector| RandomProject::project(vector.as_borrowed())); + Box::new(vectors.into_iter().zip(original_vectors).map( + |(vector, original_vector)| { + let (results, estimation_by_threshold) = match options.io_search { + Io::Plain => maxsim_search::<_, Op>( + index, + vector.clone(), + options.probes.clone(), + options.epsilon, + maxsim_threshold, + bump, + make_h1_plain_prefetcher.clone(), + make_h0_plain_prefetcher.clone(), + ), + Io::Simple => maxsim_search::<_, Op>( + index, + vector.clone(), + options.probes.clone(), + options.epsilon, + maxsim_threshold, + bump, + make_h1_plain_prefetcher.clone(), + make_h0_simple_prefetcher.clone(), + ), + Io::Stream => maxsim_search::<_, Op>( + index, + vector.clone(), + options.probes.clone(), + options.epsilon, + maxsim_threshold, + bump, + make_h1_plain_prefetcher.clone(), + make_h0_stream_prefetcher.clone(), + ), + }; + let (mut accu_set, mut rough_set) = (Vec::new(), Vec::new()); + if maxsim_refine != 0 && !results.is_empty() { + let seq = seq_filter(BinaryHeap::from(results), prefilter(), |key| { + let (key, _) = pointer_to_kv(key.1.0.0); + fetcher.filter(key) + }); + match options.io_rerank { + Io::Plain => { + let prefetcher = PlainPrefetcher::<_, _>::new(index, seq); + let mut reranker = rerank_index::( + original_vector.clone(), + prefetcher, + ); + accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); + let (rough_iter, accu_iter) = reranker.finish(); + accu_set.extend(accu_iter.map(accu_map)); + rough_set.extend(rough_iter.into_iter().map(rough_map)); + } + Io::Simple => { + let prefetcher = SimplePrefetcher::<'a, R, _>::new(index, seq); + let mut reranker = rerank_index::( + original_vector.clone(), + prefetcher, + ); + accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); + let (rough_iter, accu_iter) = reranker.finish(); + accu_set.extend(accu_iter.map(accu_map)); + rough_set.extend(rough_iter.into_iter().map(rough_map)); + } + Io::Stream => { + let prefetcher = + StreamPrefetcher::<_, _>::new(index, seq, Hints::default()); + let mut reranker = rerank_index::( + original_vector.clone(), + prefetcher, + ); + accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); + let (rough_iter, accu_iter) = reranker.finish(); + accu_set.extend(accu_iter.map(accu_map)); + rough_set.extend(rough_iter.into_iter().map(rough_map)); + } } + } else { + let rough_iter = results.into_iter(); + rough_set.extend(rough_iter.map(rough_map)); } - } else { - let rough_iter = results.into_iter(); - rough_set.extend(rough_iter.map(rough_map)); - } - (accu_set, rough_set, estimation_by_threshold) - })) + (accu_set, rough_set, estimation_by_threshold) + }, + )) } }; let mut updates = Vec::new(); diff --git a/src/index/scanners/mod.rs b/src/index/scanners/mod.rs index 9be42971..aa4a2b5f 100644 --- a/src/index/scanners/mod.rs +++ b/src/index/scanners/mod.rs @@ -63,10 +63,14 @@ pub trait SearchBuilder: 'static { pub trait SearchFetcher { fn fetch(&mut self, ctid: [u16; 3]) -> Option<(&[Datum; 32], &[bool; 32])>; + fn filter(&self, key: [u16; 3]) -> bool; } impl T> SearchFetcher for LazyCell { fn fetch(&mut self, key: [u16; 3]) -> Option<(&[Datum; 32], &[bool; 32])> { LazyCell::force_mut(self).fetch(key) } + fn filter(&self, key: [u16; 3]) -> bool { + LazyCell::force(self).filter(key) + } } diff --git a/tests/logic/distance.slt b/tests/general/distance.slt similarity index 100% rename from tests/logic/distance.slt rename to tests/general/distance.slt diff --git a/tests/logic/external_build.slt b/tests/general/external_build.slt similarity index 95% rename from tests/logic/external_build.slt rename to tests/general/external_build.slt index 81dc35d8..0fb9d7f1 100644 --- a/tests/logic/external_build.slt +++ b/tests/general/external_build.slt @@ -1,3 +1,6 @@ +statement ok +DROP TABLE IF EXISTS t, vector_centroid, halfvec_centroid, real_centroid, bad_type_centroid, bad_duplicate_id; + statement ok CREATE TABLE t (val0 vector(3), val1 halfvec(3)); diff --git a/tests/logic/index.slt b/tests/general/index.slt similarity index 100% rename from tests/logic/index.slt rename to tests/general/index.slt diff --git a/tests/logic/issue427.slt b/tests/general/issue427.slt similarity index 100% rename from tests/logic/issue427.slt rename to tests/general/issue427.slt diff --git a/tests/logic/multivector.slt b/tests/general/multivector.slt similarity index 100% rename from tests/logic/multivector.slt rename to tests/general/multivector.slt diff --git a/tests/logic/null.fail b/tests/general/null.fail similarity index 100% rename from tests/logic/null.fail rename to tests/general/null.fail diff --git a/tests/logic/partition.slt b/tests/general/partition.slt similarity index 95% rename from tests/logic/partition.slt rename to tests/general/partition.slt index 1e681c5e..a3d9a220 100644 --- a/tests/logic/partition.slt +++ b/tests/general/partition.slt @@ -1,3 +1,6 @@ +statement ok +DROP TABLE IF EXISTS id_789, id_456, id_123, t; + # partition table statement ok CREATE TABLE t (val vector(3), category_id int) PARTITION BY LIST(category_id); diff --git a/tests/logic/pin.slt b/tests/general/pin.slt similarity index 100% rename from tests/logic/pin.slt rename to tests/general/pin.slt diff --git a/tests/logic/pushdown_plan.slt b/tests/general/pushdown_plan.slt similarity index 100% rename from tests/logic/pushdown_plan.slt rename to tests/general/pushdown_plan.slt diff --git a/tests/logic/pushdown_range.slt b/tests/general/pushdown_range.slt similarity index 100% rename from tests/logic/pushdown_range.slt rename to tests/general/pushdown_range.slt diff --git a/tests/logic/reindex.slt b/tests/general/reindex.slt similarity index 100% rename from tests/logic/reindex.slt rename to tests/general/reindex.slt diff --git a/tests/general/rerank_in_index.slt b/tests/general/rerank_in_index.slt new file mode 100644 index 00000000..9258f44d --- /dev/null +++ b/tests/general/rerank_in_index.slt @@ -0,0 +1,104 @@ +statement ok +DROP TABLE IF EXISTS t_expr, t_column; + +statement ok +SET enable_seqscan = off; + +statement ok +CREATE TABLE t_column (id integer, val vector(3)); + +statement ok +INSERT INTO t_column (id, val) SELECT id, ARRAY[id, id, id]::real[] FROM generate_series(1, 10000) s(id); + +statement ok +CREATE INDEX ON t_column USING vchordrq (val vector_l2_ops) +WITH (options = $$ +residual_quantization = false +rerank_in_table = false +[build.internal] +lists = [] +$$); + +statement ok +SET vchordrq.probes = ''; + +query I +SELECT id FROM t_column ORDER BY val <-> '[1.9, 1.9, 1.9]' limit 9; +---- +2 +1 +3 +4 +5 +6 +7 +8 +9 + +statement ok +DROP TABLE t_column; + +statement ok +CREATE TABLE t_expr (id integer); + +statement ok +INSERT INTO t_expr (id) SELECT id FROM generate_series(1, 10000) s(id); + +statement ok +CREATE INDEX ON t_expr USING vchordrq ((ARRAY[id::real, id::real, id::real]::vector(3)) vector_l2_ops) +WITH (options = $$ +residual_quantization = false +rerank_in_table = false +[build.internal] +lists = [] +$$); + +statement ok +SET vchordrq.probes = ''; + +query I +SELECT id FROM t_expr ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> '[1.9, 1.9, 1.9]' limit 9; +---- +2 +1 +3 +4 +5 +6 +7 +8 +9 + +query I +SELECT id FROM t_expr WHERE id <= 5 OR id % 2 = 1 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> '[1.9, 1.9, 1.9]' LIMIT 9; +---- +2 +1 +3 +4 +5 +7 +9 +11 +13 + +statement ok +SET vchordrq.prefilter to off; + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +statement ok +SET vchordrq.prefilter to on; + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +statement ok +DROP TABLE t_expr; diff --git a/tests/logic/rerank_in_table.slt b/tests/general/rerank_in_table.slt similarity index 94% rename from tests/logic/rerank_in_table.slt rename to tests/general/rerank_in_table.slt index 8a2b8c44..8bab097c 100644 --- a/tests/logic/rerank_in_table.slt +++ b/tests/general/rerank_in_table.slt @@ -1,3 +1,6 @@ +statement ok +DROP TABLE IF EXISTS t_expr, t_column; + statement ok SET enable_seqscan = off; @@ -80,7 +83,7 @@ SELECT id FROM t_expr WHERE id <= 5 OR id % 2 = 1 ORDER BY ARRAY[id::real, id::r 13 statement ok -SET vchordrq.prererank_filtering to off; +SET vchordrq.prefilter to off; query I SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); @@ -89,7 +92,7 @@ SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 OR {2,3,1,4,5,6,7,8,9} statement ok -SET vchordrq.prererank_filtering to on; +SET vchordrq.prefilter to on; query I SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); diff --git a/tests/logic/vector.slt b/tests/general/vector.slt similarity index 100% rename from tests/logic/vector.slt rename to tests/general/vector.slt diff --git a/tests/pg16/filter_rerank_in_index.slt b/tests/pg16/filter_rerank_in_index.slt new file mode 100644 index 00000000..69c6e9e5 --- /dev/null +++ b/tests/pg16/filter_rerank_in_index.slt @@ -0,0 +1,101 @@ +statement ok +SET enable_seqscan = off; + +statement ok +CREATE TABLE t_expr (id integer); + +statement ok +INSERT INTO t_expr (id) SELECT id FROM generate_series(1, 10000) s(id); + +statement ok +CREATE INDEX ON t_expr USING vchordrq ((ARRAY[id::real, id::real, id::real]::vector(3)) vector_l2_ops) +WITH (options = $$ +residual_quantization = false +rerank_in_table = false +[build.internal] +lists = [] +$$); + +statement ok +SET vchordrq.probes = ''; + +# non-heap rerank + no prefetch + postfilter +statement ok +SET vchordrq.prefilter to off; + +statement ok +SET vchordrq.io_rerank to 'read_buffer'; + +query I +SELECT id FROM t_expr ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> '[1.9, 1.9, 1.9]' limit 9; +---- +2 +1 +3 +4 +5 +6 +7 +8 +9 + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +# non-heap rerank + simple prefetch + postfilter +statement ok +SET vchordrq.prefilter to off; + +statement ok +SET vchordrq.io_search to 'prefetch_buffer'; + +query I +SELECT id FROM t_expr ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> '[1.9, 1.9, 1.9]' limit 9; +---- +2 +1 +3 +4 +5 +6 +7 +8 +9 + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +# non-heap rerank + no prefetch + prefilter +statement ok +SET vchordrq.prefilter to on; + +statement ok +SET vchordrq.io_rerank to 'read_buffer'; + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +# non-heap rerank + simple prefetch + prefilter +statement ok +SET vchordrq.prefilter to on; + +statement ok +SET vchordrq.io_search to 'prefetch_buffer'; + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +statement ok +DROP TABLE t_expr; diff --git a/tests/pg16/filter_rerank_in_table.slt b/tests/pg16/filter_rerank_in_table.slt new file mode 100644 index 00000000..c556207b --- /dev/null +++ b/tests/pg16/filter_rerank_in_table.slt @@ -0,0 +1,101 @@ +statement ok +SET enable_seqscan = off; + +statement ok +CREATE TABLE t_expr (id integer); + +statement ok +INSERT INTO t_expr (id) SELECT id FROM generate_series(1, 10000) s(id); + +statement ok +CREATE INDEX ON t_expr USING vchordrq ((ARRAY[id::real, id::real, id::real]::vector(3)) vector_l2_ops) +WITH (options = $$ +residual_quantization = false +rerank_in_table = true +[build.internal] +lists = [] +$$); + +statement ok +SET vchordrq.probes = ''; + +# heap rerank + no prefetch + postfilter +statement ok +SET vchordrq.prefilter to off; + +statement ok +SET vchordrq.io_rerank to 'read_buffer'; + +query I +SELECT id FROM t_expr ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> '[1.9, 1.9, 1.9]' limit 9; +---- +2 +1 +3 +4 +5 +6 +7 +8 +9 + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +# heap rerank + simple prefetch + postfilter +statement ok +SET vchordrq.prefilter to off; + +statement ok +SET vchordrq.io_search to 'prefetch_buffer'; + +query I +SELECT id FROM t_expr ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> '[1.9, 1.9, 1.9]' limit 9; +---- +2 +1 +3 +4 +5 +6 +7 +8 +9 + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +# heap rerank + no prefetch + prefilter +statement ok +SET vchordrq.prefilter to on; + +statement ok +SET vchordrq.io_rerank to 'read_buffer'; + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +# heap rerank + simple prefetch + prefilter +statement ok +SET vchordrq.prefilter to on; + +statement ok +SET vchordrq.io_search to 'prefetch_buffer'; + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +statement ok +DROP TABLE t_expr; diff --git a/tests/pg17/filter_rerank_in_index.slt b/tests/pg17/filter_rerank_in_index.slt new file mode 100644 index 00000000..3b0e9b10 --- /dev/null +++ b/tests/pg17/filter_rerank_in_index.slt @@ -0,0 +1,101 @@ +statement ok +SET enable_seqscan = off; + +statement ok +CREATE TABLE t_expr (id integer); + +statement ok +INSERT INTO t_expr (id) SELECT id FROM generate_series(1, 10000) s(id); + +statement ok +CREATE INDEX ON t_expr USING vchordrq ((ARRAY[id::real, id::real, id::real]::vector(3)) vector_l2_ops) +WITH (options = $$ +residual_quantization = false +rerank_in_table = false +[build.internal] +lists = [] +$$); + +statement ok +SET vchordrq.probes = ''; + +# non-heap rerank + no prefetch + postfilter +statement ok +SET vchordrq.prefilter to off; + +statement ok +SET vchordrq.io_rerank to 'read_buffer'; + +query I +SELECT id FROM t_expr ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> '[1.9, 1.9, 1.9]' limit 9; +---- +2 +1 +3 +4 +5 +6 +7 +8 +9 + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +# non-heap rerank + stream prefetch + postfilter +statement ok +SET vchordrq.prefilter to off; + +statement ok +SET vchordrq.io_search to 'read_stream'; + +query I +SELECT id FROM t_expr ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> '[1.9, 1.9, 1.9]' limit 9; +---- +2 +1 +3 +4 +5 +6 +7 +8 +9 + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +# non-heap rerank + no prefetch + prefilter +statement ok +SET vchordrq.prefilter to on; + +statement ok +SET vchordrq.io_rerank to 'read_buffer'; + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +# non-heap rerank + stream prefetch + prefilter +statement ok +SET vchordrq.prefilter to on; + +statement ok +SET vchordrq.io_search to 'read_stream'; + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +statement ok +DROP TABLE t_expr; diff --git a/tests/pg17/filter_rerank_in_table.slt b/tests/pg17/filter_rerank_in_table.slt new file mode 100644 index 00000000..19f0bf60 --- /dev/null +++ b/tests/pg17/filter_rerank_in_table.slt @@ -0,0 +1,101 @@ +statement ok +SET enable_seqscan = off; + +statement ok +CREATE TABLE t_expr (id integer); + +statement ok +INSERT INTO t_expr (id) SELECT id FROM generate_series(1, 10000) s(id); + +statement ok +CREATE INDEX ON t_expr USING vchordrq ((ARRAY[id::real, id::real, id::real]::vector(3)) vector_l2_ops) +WITH (options = $$ +residual_quantization = false +rerank_in_table = true +[build.internal] +lists = [] +$$); + +statement ok +SET vchordrq.probes = ''; + +# heap rerank + no prefetch + postfilter +statement ok +SET vchordrq.prefilter to off; + +statement ok +SET vchordrq.io_rerank to 'read_buffer'; + +query I +SELECT id FROM t_expr ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> '[1.9, 1.9, 1.9]' limit 9; +---- +2 +1 +3 +4 +5 +6 +7 +8 +9 + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +# heap rerank + stream prefetch + postfilter +statement ok +SET vchordrq.prefilter to off; + +statement ok +SET vchordrq.io_search to 'read_stream'; + +query I +SELECT id FROM t_expr ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> '[1.9, 1.9, 1.9]' limit 9; +---- +2 +1 +3 +4 +5 +6 +7 +8 +9 + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +# heap rerank + no prefetch + prefilter +statement ok +SET vchordrq.prefilter to on; + +statement ok +SET vchordrq.io_rerank to 'read_buffer'; + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +# heap rerank + stream prefetch + prefilter +statement ok +SET vchordrq.prefilter to on; + +statement ok +SET vchordrq.io_search to 'read_stream'; + +query I +SELECT ARRAY(SELECT id FROM t_expr WHERE (id <= 5 OR id % 2 = 1) OR e >= 2000 ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> q LIMIT 9) FROM (VALUES ('[1.9,1.99,1.999]'::vector, 1999), ('[2.1,2.11,2.111]', 2111)) AS t(q, e); +---- +{2,1,3,4,5,7,9,11,13} +{2,3,1,4,5,6,7,8,9} + +statement ok +DROP TABLE t_expr; From 6f8ba9a106373d9657d322135a383fd7be994ffd Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Tue, 13 May 2025 20:12:04 +0800 Subject: [PATCH 2/2] fix by comments Signed-off-by: cutecutecat --- .github/workflows/check.yml | 4 +- crates/algorithm/src/fast_heap.rs | 4 - crates/algorithm/src/lib.rs | 55 +-- src/index/am/mod.rs | 42 +-- src/index/scanners/default.rs | 585 ++++++++++++++++++++++-------- src/index/scanners/maxsim.rs | 130 ++++++- src/index/scanners/mod.rs | 6 +- tests/general/external_build.slt | 3 - tests/general/partition.slt | 3 - tests/general/rerank_in_index.slt | 3 - tests/general/rerank_in_table.slt | 3 - 11 files changed, 567 insertions(+), 271 deletions(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 97083aa0..2978320b 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -237,14 +237,14 @@ jobs: psql -c 'CREATE EXTENSION IF NOT EXISTS vchord CASCADE;' sqllogictest --db $USER --user $USER './tests/general/*.slt' - - name: Sqllogictest(PostgrSQL 17 features) + - name: Sqllogictest(PostgreSQL 17 features) if: matrix.version == '17' run: | sudo systemctl start postgresql psql -c 'CREATE EXTENSION IF NOT EXISTS vchord CASCADE;' sqllogictest --db $USER --user $USER './tests/pg17/*.slt' - - name: Sqllogictest(PostgrSQL 16 features) + - name: Sqllogictest(PostgreSQL 16 features) if: matrix.version == '16' run: | sudo systemctl start postgresql diff --git a/crates/algorithm/src/fast_heap.rs b/crates/algorithm/src/fast_heap.rs index fd7db993..f70fd176 100644 --- a/crates/algorithm/src/fast_heap.rs +++ b/crates/algorithm/src/fast_heap.rs @@ -82,10 +82,6 @@ impl Sequence for FastHeap { fn next(&mut self) -> Option { self.pop() } - fn next_if(&mut self, predicate: impl FnOnce(&T) -> bool) -> Option { - let first = self.peek()?; - if predicate(first) { self.pop() } else { None } - } fn into_inner(self) -> Self::Inner { match self { FastHeap::Sorted(sort_heap) => sort_heap.inner.into_iter(), diff --git a/crates/algorithm/src/lib.rs b/crates/algorithm/src/lib.rs index b13af61d..e06b9a11 100644 --- a/crates/algorithm/src/lib.rs +++ b/crates/algorithm/src/lib.rs @@ -221,7 +221,7 @@ impl<'b, T, A, B> Fetch for (T, AlwaysEqual<&'b mut (A, B, &'b mut [u32])>) { pub struct Filter { pub iter: S, - pub predicate: P, + pub filter: P, } impl bool> Sequence for Filter { @@ -229,29 +229,21 @@ impl bool> Sequence for Filter { type Inner = S::Inner; fn peek(&mut self) -> Option<&Self::Item> { - self.iter.peek() - } - - fn next(&mut self) -> Option { loop { let item = self.iter.peek()?; - if (self.predicate)(item) { - return self.iter.next(); + if (self.filter)(item) { + return self.iter.peek(); } else { self.iter.next(); continue; } } } - - fn next_if(&mut self, predicate: impl FnOnce(&Self::Item) -> bool) -> Option { + fn next(&mut self) -> Option { loop { let item = self.iter.peek()?; - if (self.predicate)(item) { - return match predicate(item) { - true => self.iter.next(), - false => None, - }; + if (self.filter)(item) { + return self.iter.next(); } else { self.iter.next(); continue; @@ -269,18 +261,11 @@ pub trait Sequence { type Inner: Iterator; fn peek(&mut self) -> Option<&Self::Item>; fn next(&mut self) -> Option; - fn next_if(&mut self, predicate: impl FnOnce(&Self::Item) -> bool) -> Option; - fn into_inner(self) -> Self::Inner; - fn filter

(self, predicate: P) -> Filter - where - Self: Sized, - P: FnMut(&Self::Item) -> bool, - { - Filter { - iter: self, - predicate, - } + fn next_if(&mut self, predicate: impl FnOnce(&Self::Item) -> bool) -> Option { + let peek = self.peek()?; + if predicate(peek) { self.next() } else { None } } + fn into_inner(self) -> Self::Inner; } impl Sequence for BinaryHeap { @@ -292,10 +277,6 @@ impl Sequence for BinaryHeap { fn next(&mut self) -> Option { self.pop() } - fn next_if(&mut self, predicate: impl FnOnce(&T) -> bool) -> Option { - let peek = self.peek()?; - if predicate(peek) { self.pop() } else { None } - } fn into_inner(self) -> Self::Inner { self.into_vec().into_iter() } @@ -310,25 +291,15 @@ impl Sequence for Peekable { fn next(&mut self) -> Option { Iterator::next(self) } - fn next_if(&mut self, predicate: impl FnOnce(&I::Item) -> bool) -> Option { - Peekable::next_if(self, predicate) - } fn into_inner(self) -> Self::Inner { self } } -pub fn seq_filter( - heap: impl Sequence, - prefilter: bool, - filter: F, -) -> impl Sequence +pub fn seq_filter(heap: impl Sequence, filter: F) -> impl Sequence where - F: Fn(&T) -> bool, + F: FnMut(&T) -> bool, T: Ord, { - heap.filter(move |t| match prefilter { - true => filter(t), - false => true, - }) + Filter { iter: heap, filter } } diff --git a/src/index/am/mod.rs b/src/index/am/mod.rs index cc2d89a0..2e1dbaf1 100644 --- a/src/index/am/mod.rs +++ b/src/index/am/mod.rs @@ -15,7 +15,6 @@ pub mod am_build; use super::algorithm::BumpAlloc; -use super::gucs::prefilter; use crate::index::gucs; use crate::index::lazy_cell::LazyCell; use crate::index::opclass::{Opfamily, opfamily}; @@ -571,13 +570,6 @@ impl Drop for HeapFetcher { } impl SearchFetcher for HeapFetcher { - /// Fetches data associated with the given `key` passes the filter criteria. - /// - /// # Parameters - /// - `key`: A 3-element array representing the key to be checked. - /// - /// Returns a tuple containing a reference to an array of `Datum` and a reference to an array of `bool`, - /// or `None` if no data is found. fn fetch(&mut self, key: [u16; 3]) -> Option<(&[Datum; 32], &[bool; 32])> { unsafe { let mut ctid = key_to_ctid(key); @@ -588,28 +580,6 @@ impl SearchFetcher for HeapFetcher { if !fetch_row_version(self.heap_relation, &mut ctid, self.snapshot, self.slot) { return None; } - if !self.hack.is_null() && prefilter() { - if let Some(qual) = NonNull::new((*self.hack).ss.ps.qual) { - use pgrx::datum::FromDatum; - use pgrx::memcxt::PgMemoryContexts; - assert!(qual.as_ref().flags & pgrx::pg_sys::EEO_FLAG_IS_QUAL as u8 != 0); - let evalfunc = qual.as_ref().evalfunc.expect("no evalfunc for qual"); - if !(*self.hack).ss.ps.ps_ExprContext.is_null() { - let econtext = (*self.hack).ss.ps.ps_ExprContext; - (*econtext).ecxt_scantuple = self.slot; - pgrx::pg_sys::MemoryContextReset((*econtext).ecxt_per_tuple_memory); - let result = PgMemoryContexts::For((*econtext).ecxt_per_tuple_memory) - .switch_to(|_| { - let mut is_null = true; - let datum = evalfunc(qual.as_ptr(), econtext, &mut is_null); - bool::from_datum(datum, is_null) - }); - if result != Some(true) { - return None; - } - } - } - } (*self.econtext).ecxt_scantuple = self.slot; pgrx::pg_sys::MemoryContextReset((*self.econtext).ecxt_per_tuple_memory); pgrx::pg_sys::FormIndexDatum( @@ -623,16 +593,8 @@ impl SearchFetcher for HeapFetcher { } } - /// Determines whether the given `key` passes the filter criteria. - /// - /// # Parameters - /// - `key`: A 3-element array representing the key to be checked. - /// - /// # Returns - /// - `true` if the key satisfies the filter conditions. - /// - `false` otherwise. - fn filter(&self, key: [u16; 3]) -> bool { - if self.hack.is_null() || !prefilter() { + fn filter(&mut self, key: [u16; 3]) -> bool { + if self.hack.is_null() { return true; } unsafe { diff --git a/src/index/scanners/default.rs b/src/index/scanners/default.rs index dbdd209b..021d514c 100644 --- a/src/index/scanners/default.rs +++ b/src/index/scanners/default.rs @@ -17,13 +17,10 @@ use crate::index::algorithm::*; use crate::index::am::pointer_to_kv; use crate::index::gucs::prefilter; use crate::index::opclass::{Opfamily, Sphere}; -use algorithm::operator::{Dot, L2, Op, Operator}; +use algorithm::operator::{Dot, L2}; use algorithm::types::{DistanceKind, OwnedVector, VectorKind}; use algorithm::*; -use always_equal::AlwaysEqual; -use distance::Distance; use half::f16; -use std::cmp::Reverse; use std::collections::BinaryHeap; use std::num::NonZero; use vector::VectorOwned; @@ -71,7 +68,7 @@ impl SearchBuilder for DefaultBuilder { self, index: &'a R, options: SearchOptions, - fetcher: impl SearchFetcher + 'a, + mut fetcher: impl SearchFetcher + 'a, bump: &'a impl Bump, ) -> Box + 'a> where @@ -108,6 +105,7 @@ impl SearchBuilder for DefaultBuilder { let iter: Box)>> = match (opfamily.vector_kind(), opfamily.distance_kind()) { (VectorKind::Vecf32, DistanceKind::L2) => { + type Op = operator::Op, L2>; let original_vector = if let OwnedVector::Vecf32(vector) = vector { vector } else { @@ -115,7 +113,7 @@ impl SearchBuilder for DefaultBuilder { }; let vector = RandomProject::project(original_vector.as_borrowed()); let results = match options.io_search { - Io::Plain => default_search::<_, Op, L2>>( + Io::Plain => default_search::<_, Op>( index, vector.clone(), options.probes, @@ -124,7 +122,7 @@ impl SearchBuilder for DefaultBuilder { make_h1_plain_prefetcher, make_h0_plain_prefetcher, ), - Io::Simple => default_search::<_, Op, L2>>( + Io::Simple => default_search::<_, Op>( index, vector.clone(), options.probes, @@ -133,7 +131,7 @@ impl SearchBuilder for DefaultBuilder { make_h1_plain_prefetcher, make_h0_simple_prefetcher, ), - Io::Stream => default_search::<_, Op, L2>>( + Io::Stream => default_search::<_, Op>( index, vector.clone(), options.probes, @@ -144,25 +142,114 @@ impl SearchBuilder for DefaultBuilder { ), }; let method = how(index); - let vector_insider = |vector: OwnedVector| { - if let OwnedVector::Vecf32(v) = vector { - v - } else { - unreachable!() + match (method, options.io_rerank, prefilter()) { + (RerankMethod::Index, Io::Plain, true) => { + let seq = seq_filter(BinaryHeap::from(results), move |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); + let prefetcher = PlainPrefetcher::<_, _>::new(index, seq); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) } - }; - rerank_wrapper::, L2>, _, _>( - original_vector, - index, - opfamily, - fetcher, - results, - method, - options.io_rerank, - vector_insider, - ) + (RerankMethod::Index, Io::Plain, false) => { + let prefetcher = + PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Simple, true) => { + let seq = seq_filter(BinaryHeap::from(results), move |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); + let prefetcher = SimplePrefetcher::<'a, R, _>::new(index, seq); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Simple, false) => { + let prefetcher = SimplePrefetcher::<'a, R, BinaryHeap<_>>::new( + index, + results.into(), + ); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Stream, true) => { + let seq = seq_filter(BinaryHeap::from(results), move |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); + let prefetcher = + StreamPrefetcher::<_, _>::new(index, seq, Hints::default()); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Stream, false) => { + let prefetcher = StreamPrefetcher::<_, BinaryHeap<_>>::new( + index, + results.into(), + Hints::default(), + ); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Heap, _, true) => { + let fetch = move |payload| { + let (key, _) = pointer_to_kv(payload); + if !fetcher.filter(key) { + return None; + } + let (datums, is_nulls) = fetcher.fetch(key)?; + let datum = (!is_nulls[0]).then_some(datums[0]); + let maybe_vector = + unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; + let raw = if let OwnedVector::Vecf32(vector) = maybe_vector.unwrap() + { + vector + } else { + unreachable!() + }; + Some(raw) + }; + let prefetcher = + PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); + Box::new( + rerank_heap::(original_vector, prefetcher, fetch).map( + move |(distance, payload)| (opfamily.output(distance), payload), + ), + ) + } + (RerankMethod::Heap, _, false) => { + let fetch = move |payload| { + let (key, _) = pointer_to_kv(payload); + let (datums, is_nulls) = fetcher.fetch(key)?; + let datum = (!is_nulls[0]).then_some(datums[0]); + let maybe_vector = + unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; + let raw = if let OwnedVector::Vecf32(vector) = maybe_vector.unwrap() + { + vector + } else { + unreachable!() + }; + Some(raw) + }; + let prefetcher = + PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); + Box::new( + rerank_heap::(original_vector, prefetcher, fetch).map( + move |(distance, payload)| (opfamily.output(distance), payload), + ), + ) + } + } } (VectorKind::Vecf32, DistanceKind::Dot) => { + type Op = operator::Op, Dot>; let original_vector = if let OwnedVector::Vecf32(vector) = vector { vector } else { @@ -170,7 +257,7 @@ impl SearchBuilder for DefaultBuilder { }; let vector = RandomProject::project(original_vector.as_borrowed()); let results = match options.io_search { - Io::Plain => default_search::<_, Op, Dot>>( + Io::Plain => default_search::<_, Op>( index, vector.clone(), options.probes, @@ -179,7 +266,7 @@ impl SearchBuilder for DefaultBuilder { make_h1_plain_prefetcher, make_h0_plain_prefetcher, ), - Io::Simple => default_search::<_, Op, Dot>>( + Io::Simple => default_search::<_, Op>( index, vector.clone(), options.probes, @@ -188,7 +275,7 @@ impl SearchBuilder for DefaultBuilder { make_h1_plain_prefetcher, make_h0_simple_prefetcher, ), - Io::Stream => default_search::<_, Op, Dot>>( + Io::Stream => default_search::<_, Op>( index, vector.clone(), options.probes, @@ -199,25 +286,114 @@ impl SearchBuilder for DefaultBuilder { ), }; let method = how(index); - let vector_insider = |vector: OwnedVector| { - if let OwnedVector::Vecf32(v) = vector { - v - } else { - unreachable!() + match (method, options.io_rerank, prefilter()) { + (RerankMethod::Index, Io::Plain, true) => { + let seq = seq_filter(BinaryHeap::from(results), move |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); + let prefetcher = PlainPrefetcher::<_, _>::new(index, seq); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) } - }; - rerank_wrapper::, Dot>, _, _>( - original_vector, - index, - opfamily, - fetcher, - results, - method, - options.io_rerank, - vector_insider, - ) + (RerankMethod::Index, Io::Plain, false) => { + let prefetcher = + PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Simple, true) => { + let seq = seq_filter(BinaryHeap::from(results), move |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); + let prefetcher = SimplePrefetcher::<'a, R, _>::new(index, seq); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Simple, false) => { + let prefetcher = SimplePrefetcher::<'a, R, BinaryHeap<_>>::new( + index, + results.into(), + ); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Stream, true) => { + let seq = seq_filter(BinaryHeap::from(results), move |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); + let prefetcher = + StreamPrefetcher::<_, _>::new(index, seq, Hints::default()); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Stream, false) => { + let prefetcher = StreamPrefetcher::<_, BinaryHeap<_>>::new( + index, + results.into(), + Hints::default(), + ); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Heap, _, true) => { + let fetch = move |payload| { + let (key, _) = pointer_to_kv(payload); + if !fetcher.filter(key) { + return None; + } + let (datums, is_nulls) = fetcher.fetch(key)?; + let datum = (!is_nulls[0]).then_some(datums[0]); + let maybe_vector = + unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; + let raw = if let OwnedVector::Vecf32(vector) = maybe_vector.unwrap() + { + vector + } else { + unreachable!() + }; + Some(raw) + }; + let prefetcher = + PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); + Box::new( + rerank_heap::(original_vector, prefetcher, fetch).map( + move |(distance, payload)| (opfamily.output(distance), payload), + ), + ) + } + (RerankMethod::Heap, _, false) => { + let fetch = move |payload| { + let (key, _) = pointer_to_kv(payload); + let (datums, is_nulls) = fetcher.fetch(key)?; + let datum = (!is_nulls[0]).then_some(datums[0]); + let maybe_vector = + unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; + let raw = if let OwnedVector::Vecf32(vector) = maybe_vector.unwrap() + { + vector + } else { + unreachable!() + }; + Some(raw) + }; + let prefetcher = + PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); + Box::new( + rerank_heap::(original_vector, prefetcher, fetch).map( + move |(distance, payload)| (opfamily.output(distance), payload), + ), + ) + } + } } (VectorKind::Vecf16, DistanceKind::L2) => { + type Op = operator::Op, L2>; let original_vector = if let OwnedVector::Vecf16(vector) = vector { vector } else { @@ -225,7 +401,7 @@ impl SearchBuilder for DefaultBuilder { }; let vector = RandomProject::project(original_vector.as_borrowed()); let results = match options.io_search { - Io::Plain => default_search::<_, Op, L2>>( + Io::Plain => default_search::<_, Op>( index, vector.clone(), options.probes, @@ -234,7 +410,7 @@ impl SearchBuilder for DefaultBuilder { make_h1_plain_prefetcher, make_h0_plain_prefetcher, ), - Io::Simple => default_search::<_, Op, L2>>( + Io::Simple => default_search::<_, Op>( index, vector.clone(), options.probes, @@ -243,7 +419,7 @@ impl SearchBuilder for DefaultBuilder { make_h1_plain_prefetcher, make_h0_simple_prefetcher, ), - Io::Stream => default_search::<_, Op, L2>>( + Io::Stream => default_search::<_, Op>( index, vector.clone(), options.probes, @@ -254,25 +430,114 @@ impl SearchBuilder for DefaultBuilder { ), }; let method = how(index); - let vector_insider = |vector: OwnedVector| { - if let OwnedVector::Vecf16(v) = vector { - v - } else { - unreachable!() + match (method, options.io_rerank, prefilter()) { + (RerankMethod::Index, Io::Plain, true) => { + let seq = seq_filter(BinaryHeap::from(results), move |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); + let prefetcher = PlainPrefetcher::<_, _>::new(index, seq); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) } - }; - rerank_wrapper::, L2>, _, _>( - original_vector, - index, - opfamily, - fetcher, - results, - method, - options.io_rerank, - vector_insider, - ) + (RerankMethod::Index, Io::Plain, false) => { + let prefetcher = + PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Simple, true) => { + let seq = seq_filter(BinaryHeap::from(results), move |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); + let prefetcher = SimplePrefetcher::<'a, R, _>::new(index, seq); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Simple, false) => { + let prefetcher = SimplePrefetcher::<'a, R, BinaryHeap<_>>::new( + index, + results.into(), + ); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Stream, true) => { + let seq = seq_filter(BinaryHeap::from(results), move |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); + let prefetcher = + StreamPrefetcher::<_, _>::new(index, seq, Hints::default()); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Stream, false) => { + let prefetcher = StreamPrefetcher::<_, BinaryHeap<_>>::new( + index, + results.into(), + Hints::default(), + ); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Heap, _, true) => { + let fetch = move |payload| { + let (key, _) = pointer_to_kv(payload); + if !fetcher.filter(key) { + return None; + } + let (datums, is_nulls) = fetcher.fetch(key)?; + let datum = (!is_nulls[0]).then_some(datums[0]); + let maybe_vector = + unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; + let raw = if let OwnedVector::Vecf16(vector) = maybe_vector.unwrap() + { + vector + } else { + unreachable!() + }; + Some(raw) + }; + let prefetcher = + PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); + Box::new( + rerank_heap::(original_vector, prefetcher, fetch).map( + move |(distance, payload)| (opfamily.output(distance), payload), + ), + ) + } + (RerankMethod::Heap, _, false) => { + let fetch = move |payload| { + let (key, _) = pointer_to_kv(payload); + let (datums, is_nulls) = fetcher.fetch(key)?; + let datum = (!is_nulls[0]).then_some(datums[0]); + let maybe_vector = + unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; + let raw = if let OwnedVector::Vecf16(vector) = maybe_vector.unwrap() + { + vector + } else { + unreachable!() + }; + Some(raw) + }; + let prefetcher = + PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); + Box::new( + rerank_heap::(original_vector, prefetcher, fetch).map( + move |(distance, payload)| (opfamily.output(distance), payload), + ), + ) + } + } } (VectorKind::Vecf16, DistanceKind::Dot) => { + type Op = operator::Op, Dot>; let original_vector = if let OwnedVector::Vecf16(vector) = vector { vector } else { @@ -280,7 +545,7 @@ impl SearchBuilder for DefaultBuilder { }; let vector = RandomProject::project(original_vector.as_borrowed()); let results = match options.io_search { - Io::Plain => default_search::<_, Op, Dot>>( + Io::Plain => default_search::<_, Op>( index, vector.clone(), options.probes, @@ -289,7 +554,7 @@ impl SearchBuilder for DefaultBuilder { make_h1_plain_prefetcher, make_h0_plain_prefetcher, ), - Io::Simple => default_search::<_, Op, Dot>>( + Io::Simple => default_search::<_, Op>( index, vector.clone(), options.probes, @@ -298,7 +563,7 @@ impl SearchBuilder for DefaultBuilder { make_h1_plain_prefetcher, make_h0_simple_prefetcher, ), - Io::Stream => default_search::<_, Op, Dot>>( + Io::Stream => default_search::<_, Op>( index, vector.clone(), options.probes, @@ -309,23 +574,111 @@ impl SearchBuilder for DefaultBuilder { ), }; let method = how(index); - let vector_insider = |vector: OwnedVector| { - if let OwnedVector::Vecf16(v) = vector { - v - } else { - unreachable!() + match (method, options.io_rerank, prefilter()) { + (RerankMethod::Index, Io::Plain, true) => { + let seq = seq_filter(BinaryHeap::from(results), move |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); + let prefetcher = PlainPrefetcher::<_, _>::new(index, seq); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) } - }; - rerank_wrapper::, Dot>, _, _>( - original_vector, - index, - opfamily, - fetcher, - results, - method, - options.io_rerank, - vector_insider, - ) + (RerankMethod::Index, Io::Plain, false) => { + let prefetcher = + PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Simple, true) => { + let seq = seq_filter(BinaryHeap::from(results), move |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); + let prefetcher = SimplePrefetcher::<'a, R, _>::new(index, seq); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Simple, false) => { + let prefetcher = SimplePrefetcher::<'a, R, BinaryHeap<_>>::new( + index, + results.into(), + ); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Stream, true) => { + let seq = seq_filter(BinaryHeap::from(results), move |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); + let prefetcher = + StreamPrefetcher::<_, _>::new(index, seq, Hints::default()); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Index, Io::Stream, false) => { + let prefetcher = StreamPrefetcher::<_, BinaryHeap<_>>::new( + index, + results.into(), + Hints::default(), + ); + Box::new(rerank_index::(original_vector, prefetcher).map( + move |(distance, payload)| (opfamily.output(distance), payload), + )) + } + (RerankMethod::Heap, _, true) => { + let fetch = move |payload| { + let (key, _) = pointer_to_kv(payload); + if !fetcher.filter(key) { + return None; + } + let (datums, is_nulls) = fetcher.fetch(key)?; + let datum = (!is_nulls[0]).then_some(datums[0]); + let maybe_vector = + unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; + let raw = if let OwnedVector::Vecf16(vector) = maybe_vector.unwrap() + { + vector + } else { + unreachable!() + }; + Some(raw) + }; + let prefetcher = + PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); + Box::new( + rerank_heap::(original_vector, prefetcher, fetch).map( + move |(distance, payload)| (opfamily.output(distance), payload), + ), + ) + } + (RerankMethod::Heap, _, false) => { + let fetch = move |payload| { + let (key, _) = pointer_to_kv(payload); + let (datums, is_nulls) = fetcher.fetch(key)?; + let datum = (!is_nulls[0]).then_some(datums[0]); + let maybe_vector = + unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; + let raw = if let OwnedVector::Vecf16(vector) = maybe_vector.unwrap() + { + vector + } else { + unreachable!() + }; + Some(raw) + }; + let prefetcher = + PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); + Box::new( + rerank_heap::(original_vector, prefetcher, fetch).map( + move |(distance, payload)| (opfamily.output(distance), payload), + ), + ) + } + } } }; let iter = if let Some(threshold) = threshold { @@ -344,73 +697,3 @@ impl SearchBuilder for DefaultBuilder { })) } } - -type SeqElement<'a> = ( - (Reverse, AlwaysEqual<()>), - AlwaysEqual<&'a mut (NonZero, u16, &'a mut [u32])>, -); - -#[allow(clippy::too_many_arguments)] -fn rerank_wrapper<'a, O: Operator>, R, T>( - vector: O::Vector, - index: &'a R, - opfamily: Opfamily, - mut fetcher: impl SearchFetcher + 'a, - results: Vec>, - method: RerankMethod, - io_rerank: Io, - vector_insider: impl Fn(OwnedVector) -> VectOwned + 'a, -) -> Box)> + 'a> -where - R: RelationRead + RelationPrefetch + RelationReadStream, -{ - match (method, io_rerank) { - (RerankMethod::Index, Io::Plain) => { - let seq = seq_filter(BinaryHeap::from(results), prefilter(), move |key| { - let (key, _) = pointer_to_kv(key.1.0.0); - fetcher.filter(key) - }); - let prefetcher = PlainPrefetcher::<_, _>::new(index, seq); - Box::new( - rerank_index::(vector, prefetcher) - .map(move |(distance, payload)| (opfamily.output(distance), payload)), - ) - } - (RerankMethod::Index, Io::Simple) => { - let seq = seq_filter(BinaryHeap::from(results), prefilter(), move |key| { - let (key, _) = pointer_to_kv(key.1.0.0); - fetcher.filter(key) - }); - let prefetcher = SimplePrefetcher::<'a, R, _>::new(index, seq); - Box::new( - rerank_index::(vector, prefetcher) - .map(move |(distance, payload)| (opfamily.output(distance), payload)), - ) - } - (RerankMethod::Index, Io::Stream) => { - let seq = seq_filter(BinaryHeap::from(results), prefilter(), move |key| { - let (key, _) = pointer_to_kv(key.1.0.0); - fetcher.filter(key) - }); - let prefetcher = StreamPrefetcher::<_, _>::new(index, seq, Hints::default()); - Box::new( - rerank_index::(vector, prefetcher) - .map(move |(distance, payload)| (opfamily.output(distance), payload)), - ) - } - (RerankMethod::Heap, _) => { - let fetch = move |payload| { - let (key, _) = pointer_to_kv(payload); - let (datums, is_nulls) = fetcher.fetch(key)?; - let datum = (!is_nulls[0]).then_some(datums[0]); - let maybe_vector = unsafe { datum.and_then(|x| opfamily.input_vector(x)) }; - Some(vector_insider(maybe_vector.unwrap())) - }; - let prefetcher = PlainPrefetcher::<_, BinaryHeap<_>>::new(index, results.into()); - Box::new( - rerank_heap::(vector, prefetcher, fetch) - .map(move |(distance, payload)| (opfamily.output(distance), payload)), - ) - } - } -} diff --git a/src/index/scanners/maxsim.rs b/src/index/scanners/maxsim.rs index febd5e1a..9c39d800 100644 --- a/src/index/scanners/maxsim.rs +++ b/src/index/scanners/maxsim.rs @@ -61,7 +61,7 @@ impl SearchBuilder for MaxsimBuilder { self, index: &'a R, options: SearchOptions, - fetcher: impl SearchFetcher + 'a, + mut fetcher: impl SearchFetcher + 'a, bump: &'a impl Bump, ) -> Box + 'a> where @@ -155,12 +155,11 @@ impl SearchBuilder for MaxsimBuilder { }; let (mut accu_set, mut rough_set) = (Vec::new(), Vec::new()); if maxsim_refine != 0 && !results.is_empty() { - let seq = seq_filter(BinaryHeap::from(results), prefilter(), |key| { - let (key, _) = pointer_to_kv(key.1.0.0); - fetcher.filter(key) - }); - match options.io_rerank { - Io::Plain => { + match (options.io_rerank, prefilter()) { + (Io::Plain, true) => { + let seq = seq_filter(BinaryHeap::from(results), |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); let prefetcher = PlainPrefetcher::<_, _>::new(index, seq); let mut reranker = rerank_index::( original_vector.clone(), @@ -171,7 +170,24 @@ impl SearchBuilder for MaxsimBuilder { accu_set.extend(accu_iter.map(accu_map)); rough_set.extend(rough_iter.into_iter().map(rough_map)); } - Io::Simple => { + (Io::Plain, false) => { + let prefetcher = PlainPrefetcher::<_, _>::new( + index, + BinaryHeap::from(results), + ); + let mut reranker = rerank_index::( + original_vector.clone(), + prefetcher, + ); + accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); + let (rough_iter, accu_iter) = reranker.finish(); + accu_set.extend(accu_iter.map(accu_map)); + rough_set.extend(rough_iter.into_iter().map(rough_map)); + } + (Io::Simple, true) => { + let seq = seq_filter(BinaryHeap::from(results), |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); let prefetcher = SimplePrefetcher::<'a, R, _>::new(index, seq); let mut reranker = rerank_index::( original_vector.clone(), @@ -182,7 +198,24 @@ impl SearchBuilder for MaxsimBuilder { accu_set.extend(accu_iter.map(accu_map)); rough_set.extend(rough_iter.into_iter().map(rough_map)); } - Io::Stream => { + (Io::Simple, false) => { + let prefetcher = SimplePrefetcher::<'a, R, _>::new( + index, + BinaryHeap::from(results), + ); + let mut reranker = rerank_index::( + original_vector.clone(), + prefetcher, + ); + accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); + let (rough_iter, accu_iter) = reranker.finish(); + accu_set.extend(accu_iter.map(accu_map)); + rough_set.extend(rough_iter.into_iter().map(rough_map)); + } + (Io::Stream, true) => { + let seq = seq_filter(BinaryHeap::from(results), |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); let prefetcher = StreamPrefetcher::<_, _>::new(index, seq, Hints::default()); let mut reranker = rerank_index::( @@ -194,6 +227,21 @@ impl SearchBuilder for MaxsimBuilder { accu_set.extend(accu_iter.map(accu_map)); rough_set.extend(rough_iter.into_iter().map(rough_map)); } + (Io::Stream, false) => { + let prefetcher = StreamPrefetcher::<_, _>::new( + index, + BinaryHeap::from(results), + Hints::default(), + ); + let mut reranker = rerank_index::( + original_vector.clone(), + prefetcher, + ); + accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); + let (rough_iter, accu_iter) = reranker.finish(); + accu_set.extend(accu_iter.map(accu_map)); + rough_set.extend(rough_iter.into_iter().map(rough_map)); + } } } else { let rough_iter = results.into_iter(); @@ -255,12 +303,11 @@ impl SearchBuilder for MaxsimBuilder { }; let (mut accu_set, mut rough_set) = (Vec::new(), Vec::new()); if maxsim_refine != 0 && !results.is_empty() { - let seq = seq_filter(BinaryHeap::from(results), prefilter(), |key| { - let (key, _) = pointer_to_kv(key.1.0.0); - fetcher.filter(key) - }); - match options.io_rerank { - Io::Plain => { + match (options.io_rerank, prefilter()) { + (Io::Plain, true) => { + let seq = seq_filter(BinaryHeap::from(results), |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); let prefetcher = PlainPrefetcher::<_, _>::new(index, seq); let mut reranker = rerank_index::( original_vector.clone(), @@ -271,7 +318,24 @@ impl SearchBuilder for MaxsimBuilder { accu_set.extend(accu_iter.map(accu_map)); rough_set.extend(rough_iter.into_iter().map(rough_map)); } - Io::Simple => { + (Io::Plain, false) => { + let prefetcher = PlainPrefetcher::<_, _>::new( + index, + BinaryHeap::from(results), + ); + let mut reranker = rerank_index::( + original_vector.clone(), + prefetcher, + ); + accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); + let (rough_iter, accu_iter) = reranker.finish(); + accu_set.extend(accu_iter.map(accu_map)); + rough_set.extend(rough_iter.into_iter().map(rough_map)); + } + (Io::Simple, true) => { + let seq = seq_filter(BinaryHeap::from(results), |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); let prefetcher = SimplePrefetcher::<'a, R, _>::new(index, seq); let mut reranker = rerank_index::( original_vector.clone(), @@ -282,7 +346,24 @@ impl SearchBuilder for MaxsimBuilder { accu_set.extend(accu_iter.map(accu_map)); rough_set.extend(rough_iter.into_iter().map(rough_map)); } - Io::Stream => { + (Io::Simple, false) => { + let prefetcher = SimplePrefetcher::<'a, R, _>::new( + index, + BinaryHeap::from(results), + ); + let mut reranker = rerank_index::( + original_vector.clone(), + prefetcher, + ); + accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); + let (rough_iter, accu_iter) = reranker.finish(); + accu_set.extend(accu_iter.map(accu_map)); + rough_set.extend(rough_iter.into_iter().map(rough_map)); + } + (Io::Stream, true) => { + let seq = seq_filter(BinaryHeap::from(results), |key| { + fetcher.filter(pointer_to_kv(key.1.0.0).0) + }); let prefetcher = StreamPrefetcher::<_, _>::new(index, seq, Hints::default()); let mut reranker = rerank_index::( @@ -294,6 +375,21 @@ impl SearchBuilder for MaxsimBuilder { accu_set.extend(accu_iter.map(accu_map)); rough_set.extend(rough_iter.into_iter().map(rough_map)); } + (Io::Stream, false) => { + let prefetcher = StreamPrefetcher::<_, _>::new( + index, + BinaryHeap::from(results), + Hints::default(), + ); + let mut reranker = rerank_index::( + original_vector.clone(), + prefetcher, + ); + accu_set.extend(reranker.by_ref().take(maxsim_refine as _)); + let (rough_iter, accu_iter) = reranker.finish(); + accu_set.extend(accu_iter.map(accu_map)); + rough_set.extend(rough_iter.into_iter().map(rough_map)); + } } } else { let rough_iter = results.into_iter(); diff --git a/src/index/scanners/mod.rs b/src/index/scanners/mod.rs index aa4a2b5f..16ddd5d7 100644 --- a/src/index/scanners/mod.rs +++ b/src/index/scanners/mod.rs @@ -63,14 +63,14 @@ pub trait SearchBuilder: 'static { pub trait SearchFetcher { fn fetch(&mut self, ctid: [u16; 3]) -> Option<(&[Datum; 32], &[bool; 32])>; - fn filter(&self, key: [u16; 3]) -> bool; + fn filter(&mut self, key: [u16; 3]) -> bool; } impl T> SearchFetcher for LazyCell { fn fetch(&mut self, key: [u16; 3]) -> Option<(&[Datum; 32], &[bool; 32])> { LazyCell::force_mut(self).fetch(key) } - fn filter(&self, key: [u16; 3]) -> bool { - LazyCell::force(self).filter(key) + fn filter(&mut self, key: [u16; 3]) -> bool { + LazyCell::force_mut(self).filter(key) } } diff --git a/tests/general/external_build.slt b/tests/general/external_build.slt index 0fb9d7f1..81dc35d8 100644 --- a/tests/general/external_build.slt +++ b/tests/general/external_build.slt @@ -1,6 +1,3 @@ -statement ok -DROP TABLE IF EXISTS t, vector_centroid, halfvec_centroid, real_centroid, bad_type_centroid, bad_duplicate_id; - statement ok CREATE TABLE t (val0 vector(3), val1 halfvec(3)); diff --git a/tests/general/partition.slt b/tests/general/partition.slt index a3d9a220..1e681c5e 100644 --- a/tests/general/partition.slt +++ b/tests/general/partition.slt @@ -1,6 +1,3 @@ -statement ok -DROP TABLE IF EXISTS id_789, id_456, id_123, t; - # partition table statement ok CREATE TABLE t (val vector(3), category_id int) PARTITION BY LIST(category_id); diff --git a/tests/general/rerank_in_index.slt b/tests/general/rerank_in_index.slt index 9258f44d..02d1e954 100644 --- a/tests/general/rerank_in_index.slt +++ b/tests/general/rerank_in_index.slt @@ -1,6 +1,3 @@ -statement ok -DROP TABLE IF EXISTS t_expr, t_column; - statement ok SET enable_seqscan = off; diff --git a/tests/general/rerank_in_table.slt b/tests/general/rerank_in_table.slt index 8bab097c..e9287976 100644 --- a/tests/general/rerank_in_table.slt +++ b/tests/general/rerank_in_table.slt @@ -1,6 +1,3 @@ -statement ok -DROP TABLE IF EXISTS t_expr, t_column; - statement ok SET enable_seqscan = off;