Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3280,6 +3280,9 @@ def _create_index_impl(
index_uuid: Optional[str] = None,
*,
target_partition_size: Optional[int] = None,
streaming_sample_rate: Optional[int] = None,
streaming_coreset_rate: Optional[int] = None,
streaming_refine_passes: Optional[int] = None,
skip_transpose: bool = False,
require_commit: bool = True,
**kwargs,
Expand Down Expand Up @@ -3491,6 +3494,12 @@ def _create_index_impl(
kwargs["num_partitions"] = num_partitions
if target_partition_size is not None:
kwargs["target_partition_size"] = target_partition_size
if streaming_sample_rate is not None:
kwargs["streaming_sample_rate"] = streaming_sample_rate
if streaming_coreset_rate is not None:
kwargs["streaming_coreset_rate"] = streaming_coreset_rate
if streaming_refine_passes is not None:
kwargs["streaming_refine_passes"] = streaming_refine_passes

if (precomputed_partition_dataset is not None) and (ivf_centroids is None):
raise ValueError(
Expand Down Expand Up @@ -3652,6 +3661,9 @@ def create_index(
index_uuid: Optional[str] = None,
*,
target_partition_size: Optional[int] = None,
streaming_sample_rate: Optional[int] = None,
streaming_coreset_rate: Optional[int] = None,
streaming_refine_passes: Optional[int] = None,
skip_transpose: bool = False,
progress_callback: Optional[Callable[[IndexProgress], None]] = None,
**kwargs,
Expand Down Expand Up @@ -3740,6 +3752,19 @@ def create_index(
The target partition size. If set, the number of partitions will be computed
based on the target partition size.
Otherwise, the target partition size will be set by index type.
streaming_sample_rate : int, optional
If set below ``sample_rate``, IVF kmeans trains incrementally and samples
at most ``num_partitions * streaming_sample_rate`` vectors per step. For
``num_partitions > 256``, chunks are compressed into a weighted coreset
and final centroids are trained with weighted hierarchical kmeans.
streaming_coreset_rate : int, optional
If set, controls the final weighted coreset budget independently from
``streaming_sample_rate``. The budget is
``num_partitions * streaming_coreset_rate``.
streaming_refine_passes : int, optional
Number of extra streaming Lloyd refinement passes to run after streaming
coreset training. Each pass loads at most
``num_partitions * streaming_sample_rate`` raw vectors at a time.
kwargs :
Parameters passed to the index building process.

Expand Down Expand Up @@ -3861,6 +3886,9 @@ def create_index(
fragment_ids=fragment_ids,
index_uuid=index_uuid,
target_partition_size=target_partition_size,
streaming_sample_rate=streaming_sample_rate,
streaming_coreset_rate=streaming_coreset_rate,
streaming_refine_passes=streaming_refine_passes,
skip_transpose=skip_transpose,
require_commit=True,
**kwargs,
Expand Down Expand Up @@ -3895,6 +3923,9 @@ def create_index_uncommitted(
index_uuid: Optional[str] = None,
*,
target_partition_size: Optional[int] = None,
streaming_sample_rate: Optional[int] = None,
streaming_coreset_rate: Optional[int] = None,
streaming_refine_passes: Optional[int] = None,
skip_transpose: bool = False,
**kwargs,
) -> Index:
Expand Down Expand Up @@ -3950,6 +3981,9 @@ def create_index_uncommitted(
fragment_ids=fragment_ids,
index_uuid=index_uuid,
target_partition_size=target_partition_size,
streaming_sample_rate=streaming_sample_rate,
streaming_coreset_rate=streaming_coreset_rate,
streaming_refine_passes=streaming_refine_passes,
skip_transpose=skip_transpose,
require_commit=False,
**kwargs,
Expand Down
10 changes: 10 additions & 0 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4293,6 +4293,16 @@ fn prepare_vector_index_params(
pq_params.max_iters = max_iters;
}

if let Some(streaming_sample_rate) = kwargs.get_item("streaming_sample_rate")? {
ivf_params.streaming_sample_rate = Some(streaming_sample_rate.extract()?);
}
if let Some(streaming_coreset_rate) = kwargs.get_item("streaming_coreset_rate")? {
ivf_params.streaming_coreset_rate = Some(streaming_coreset_rate.extract()?);
}
if let Some(streaming_refine_passes) = kwargs.get_item("streaming_refine_passes")? {
ivf_params.streaming_refine_passes = streaming_refine_passes.extract()?;
}

// Parse IVF params
if let Some(n) = kwargs.get_item("num_partitions")? {
ivf_params.num_partitions = Some(n.extract()?)
Expand Down
30 changes: 30 additions & 0 deletions rust/lance-index/src/vector/ivf/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,33 @@ pub struct IvfBuildParams {

pub sample_rate: usize,

/// Optional per-step sample rate for streaming IVF kmeans training.
///
/// When set, IVF training loads at most `num_partitions * streaming_sample_rate`
/// vectors at a time. For `num_partitions > 256`, each chunk is compressed into
/// a weighted coreset and final centroids are trained with weighted hierarchical
/// kmeans over the coreset. The coreset budget is also bounded by this rate by
/// default so large partition counts can control peak memory by lowering
/// `streaming_sample_rate`. The total number of sampled vectors remains bounded
/// by `num_partitions * sample_rate`.
pub streaming_sample_rate: Option<usize>,

/// Optional coreset rate for streaming IVF kmeans training.
///
/// When set, the final weighted coreset budget is
/// `num_partitions * streaming_coreset_rate`, independent of
/// `streaming_sample_rate`. The streaming chunk size is still controlled by
/// `streaming_sample_rate`.
pub streaming_coreset_rate: Option<usize>,

/// Number of extra streaming Lloyd refinement passes to run after streaming
/// coreset training.
///
/// Each pass reuses the same sampled vectors and only loads
/// `num_partitions * streaming_sample_rate` raw vectors at a time. This is
/// experimental and defaults to 0 to preserve existing behavior.
pub streaming_refine_passes: usize,

/// Precomputed partitions file (row_id -> partition_id)
/// mutually exclusive with `precomputed_shuffle_buffers`
pub precomputed_partitions_file: Option<String>,
Expand Down Expand Up @@ -67,6 +94,9 @@ impl Default for IvfBuildParams {
centroids: None,
retrain: false,
sample_rate: 256, // See faiss
streaming_sample_rate: None,
streaming_coreset_rate: None,
streaming_refine_passes: 0,
precomputed_partitions_file: None,
precomputed_shuffle_buffers: None,
shuffle_partition_batches: 1024 * 10,
Expand Down
128 changes: 128 additions & 0 deletions rust/lance-index/src/vector/kmeans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,24 @@ impl KMeansAlgo<u8> for KModeAlgo {
}
}

/// Cluster id assignment for each vector in a batch.
///
/// Entries that are `None` are skipped when per-centroid radii and losses are
/// accumulated from a membership.
// NOTE(review): `None` presumably marks a vector that could not be assigned to
// any centroid — confirm against the assignment routine.
pub type KMeansMembership = Vec<Option<u32>>;

/// Distance from each vector to its assigned centroid.
///
/// Paired element-by-element with a [`KMeansMembership`]; `None` entries are
/// skipped when per-centroid radii and losses are accumulated.
pub type KMeansDistances = Vec<Option<f32>>;

/// Maximum assignment distance per centroid.
///
/// Indexed by cluster id. A centroid with no contributing vectors keeps the
/// initial value `0.0`.
pub type KMeansClusterRadii = Vec<f32>;

/// Sum of assignment distances per centroid.
///
/// Indexed by cluster id. Individual `f32` distances are widened to `f64`
/// before being added.
pub type KMeansClusterLosses = Vec<f64>;

/// Batch assignment results with per-centroid radii and losses.
///
/// Tuple order is `(membership, radii, losses)`.
pub type KMeansMembershipAndLoss = (KMeansMembership, KMeansClusterRadii, KMeansClusterLosses);

/// Batch assignment results with per-vector distances.
///
/// Tuple order is `(membership, distances)`.
pub type KMeansMembershipAndDistances = (KMeansMembership, KMeansDistances);

/// KMeans implementation for Apache Arrow Arrays.
#[derive(Debug, Clone)]
pub struct KMeans {
Expand Down Expand Up @@ -636,6 +654,116 @@ impl KMeans {
Self::new_with_params(data, k, &params)
}

/// Assigns every vector in `data` to these centroids and summarizes the
/// assignment per centroid.
///
/// Returns `(membership, radii, losses)`: the per-vector membership, the largest
/// assignment distance seen for each centroid, and the summed assignment
/// distance for each centroid. Vectors whose membership or distance is `None`
/// contribute to neither summary.
///
/// # Errors
///
/// Propagates any error from [`Self::compute_membership_and_distances`].
pub fn compute_membership_and_loss(
    &self,
    data: &FixedSizeListArray,
) -> arrow::error::Result<KMeansMembershipAndLoss> {
    let (membership, distances) = self.compute_membership_and_distances(data)?;

    // The centroid buffer is flat, so its length divided by the dimension is
    // the number of centroids.
    let num_clusters = self.centroids.len() / self.dimension;
    let mut radii = vec![0.0_f32; num_clusters];
    let mut loss_per_cluster = vec![0.0_f64; num_clusters];

    // Keep only the vectors that have both a cluster id and a distance.
    let assigned = membership
        .iter()
        .zip(distances.iter())
        .filter_map(|(id, dist)| Some(((*id)? as usize, (*dist)?)));
    for (slot, dist) in assigned {
        radii[slot] = radii[slot].max(dist);
        loss_per_cluster[slot] += f64::from(dist);
    }

    Ok((membership, radii, loss_per_cluster))
}

/// Assigns each vector in `data` to one of these centroids, returning the
/// membership together with each vector's distance to its assigned centroid.
///
/// Supported combinations are `Float16`, `Float32` and `Float64` data with
/// centroids of the same type (any distance type), and `UInt8` data with
/// `UInt8` centroids under the Hamming distance.
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] when the vector length of `data`
/// differs from the centroid dimension, or when the combination of data type,
/// centroid type and distance type is not one of the supported ones. A failure
/// while preparing the centroid index is returned as
/// [`ArrowError::ExternalError`].
pub fn compute_membership_and_distances(
    &self,
    data: &FixedSizeListArray,
) -> arrow::error::Result<KMeansMembershipAndDistances> {
    let dim = self.dimension;
    if data.value_length() as usize != dim {
        return Err(ArrowError::InvalidArgumentError(format!(
            "KMeans: data dimension {} does not match centroid dimension {}",
            data.value_length(),
            dim
        )));
    }

    let metric = self.distance_type;
    // The index is prepared before the type dispatch, so an index error takes
    // precedence over an unsupported-type error.
    let index = SimpleIndex::may_train_index(self.centroids.clone(), dim, metric)
        .map_err(|err| ArrowError::ExternalError(Box::new(err)))?;

    let assignment = match (data.value_type(), self.centroids.data_type(), metric) {
        (DataType::Float16, DataType::Float16, _) => {
            KMeansAlgoFloat::<Float16Type>::compute_membership_and_dist(
                self.centroids.as_primitive::<Float16Type>().values(),
                data.values().as_primitive::<Float16Type>().values(),
                dim,
                metric,
                0.0,
                None,
                index.as_ref(),
            )
        }
        (DataType::Float32, DataType::Float32, _) => {
            KMeansAlgoFloat::<Float32Type>::compute_membership_and_dist(
                self.centroids.as_primitive::<Float32Type>().values(),
                data.values().as_primitive::<Float32Type>().values(),
                dim,
                metric,
                0.0,
                None,
                index.as_ref(),
            )
        }
        (DataType::Float64, DataType::Float64, _) => {
            KMeansAlgoFloat::<Float64Type>::compute_membership_and_dist(
                self.centroids.as_primitive::<Float64Type>().values(),
                data.values().as_primitive::<Float64Type>().values(),
                dim,
                metric,
                0.0,
                None,
                index.as_ref(),
            )
        }
        (DataType::UInt8, DataType::UInt8, DistanceType::Hamming) => {
            KModeAlgo::compute_membership_and_dist(
                self.centroids.as_primitive::<UInt8Type>().values(),
                data.values().as_primitive::<UInt8Type>().values(),
                dim,
                metric,
                0.0,
                None,
                index.as_ref(),
            )
        }
        _ => {
            return Err(ArrowError::InvalidArgumentError(format!(
                "KMeans: can not compute membership for data type {} with centroid type {} and distance type {}",
                data.value_type(),
                self.centroids.data_type(),
                self.distance_type
            )));
        }
    };

    Ok(assignment)
}

/// Total kmeans loss of `data` against these centroids.
///
/// This is the sum of the per-centroid losses produced by
/// [`Self::compute_membership_and_loss`]; the membership and radii it also
/// returns are discarded.
///
/// # Errors
///
/// Propagates any error from [`Self::compute_membership_and_loss`].
pub fn compute_loss(&self, data: &FixedSizeListArray) -> arrow::error::Result<f64> {
    self.compute_membership_and_loss(data)
        .map(|(_, _, per_cluster)| per_cluster.iter().sum::<f64>())
}

fn train_kmeans<T: ArrowNumericType, Algo: KMeansAlgo<T::Native>>(
data: &FixedSizeListArray,
k: usize,
Expand Down
3 changes: 3 additions & 0 deletions rust/lance/src/index/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1841,6 +1841,9 @@ fn derive_ivf_params(ivf_model: &IvfModel) -> IvfBuildParams {
#[allow(deprecated)]
retrain: false, // Don't retrain since we have centroids
sample_rate: 256, // Default
streaming_sample_rate: None,
streaming_coreset_rate: None,
streaming_refine_passes: 0,
precomputed_partitions_file: None,
precomputed_shuffle_buffers: None,
shuffle_partition_batches: 1024 * 10, // Default
Expand Down
Loading
Loading