From 49ece3640256ab1a15625b6374f5c9b55bbe254d Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 11 Jun 2026 16:55:47 -0400 Subject: [PATCH] remove len parameter from `ScalarFnArray::try_new` Signed-off-by: Connor Tsui --- encodings/datetime-parts/src/compute/rules.rs | 7 +- encodings/runend/src/rules.rs | 3 +- .../src/arrays/chunked/compute/rules.rs | 12 +-- vortex-array/src/arrays/dict/compute/rules.rs | 17 ++-- .../arrays/fixed_size_list/tests/nested.rs | 1 - vortex-array/src/arrays/scalar_fn/array.rs | 45 ++++++--- vortex-array/src/arrays/scalar_fn/plugin.rs | 2 +- vortex-array/src/arrays/scalar_fn/rules.rs | 8 +- .../src/arrays/scalar_fn/vtable/operations.rs | 47 +++++++++- vortex-array/src/expression.rs | 3 +- .../src/encodings/turboquant/compress.rs | 9 +- .../src/encodings/turboquant/tests/compute.rs | 17 ++-- .../encodings/turboquant/tests/nullable.rs | 2 +- .../encodings/turboquant/tests/structural.rs | 3 +- .../src/scalar_fns/cosine_similarity.rs | 94 +++++++++---------- vortex-tensor/src/scalar_fns/inner_product.rs | 92 +++++++++--------- vortex-tensor/src/scalar_fns/l2_denorm.rs | 46 +++++---- vortex-tensor/src/scalar_fns/l2_norm.rs | 28 +++--- .../src/scalar_fns/sorf_transform/mod.rs | 8 +- .../src/scalar_fns/sorf_transform/tests.rs | 34 +++---- vortex-tensor/src/utils.rs | 3 +- vortex-tensor/src/vector_search.rs | 2 +- vortex-turboquant/src/scalar_fns/decode.rs | 3 +- vortex-turboquant/src/scalar_fns/encode.rs | 3 +- vortex-turboquant/src/vector/normalize.rs | 4 +- 25 files changed, 250 insertions(+), 243 deletions(-) diff --git a/encodings/datetime-parts/src/compute/rules.rs b/encodings/datetime-parts/src/compute/rules.rs index 46bc2221e5d..bba19ceea47 100644 --- a/encodings/datetime-parts/src/compute/rules.rs +++ b/encodings/datetime-parts/src/compute/rules.rs @@ -133,10 +133,9 @@ impl ArrayParentReduceRule for DTPComparisonPushDownRule { } } - let result = - ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, parent.len())? - .into_array() - .optimize()?; + let result = ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)? + .into_array() + .optimize()?; Ok(Some(result)) } diff --git a/encodings/runend/src/rules.rs b/encodings/runend/src/rules.rs index cd80e42db46..ac7d6038146 100644 --- a/encodings/runend/src/rules.rs +++ b/encodings/runend/src/rules.rs @@ -77,8 +77,7 @@ impl ArrayParentReduceRule for RunEndScalarFnRule { } let new_values = - ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, values_len)? - .into_array(); + ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)?.into_array(); Ok(Some( unsafe { diff --git a/vortex-array/src/arrays/chunked/compute/rules.rs b/vortex-array/src/arrays/chunked/compute/rules.rs index d8d324a8e86..712578255bb 100644 --- a/vortex-array/src/arrays/chunked/compute/rules.rs +++ b/vortex-array/src/arrays/chunked/compute/rules.rs @@ -48,13 +48,9 @@ impl ArrayParentReduceRule for ChunkedUnaryScalarFnPushDownRule { let new_chunks: Vec<_> = array .iter_chunks() .map(|chunk| { - ScalarFnArray::try_new( - parent.scalar_fn().clone(), - vec![chunk.clone()], - chunk.len(), - )? - .into_array() - .optimize() + ScalarFnArray::try_new(parent.scalar_fn().clone(), vec![chunk.clone()])? + .into_array() + .optimize() }) .try_collect()?; @@ -104,7 +100,7 @@ impl ArrayParentReduceRule for ChunkedConstantScalarFnPushDownRule { }) .collect(); - ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, chunk.len())? + ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)? .into_array() .optimize() }) diff --git a/vortex-array/src/arrays/dict/compute/rules.rs b/vortex-array/src/arrays/dict/compute/rules.rs index 5150a22271a..024240e4030 100644 --- a/vortex-array/src/arrays/dict/compute/rules.rs +++ b/vortex-array/src/arrays/dict/compute/rules.rs @@ -126,10 +126,9 @@ impl ArrayParentReduceRule for DictionaryScalarFnValuesPushDownRule { } } - let new_values = - ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, values_len)? - .into_array() - .optimize()?; + let new_values = ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)? + .into_array() + .optimize()?; // We can only push down null-sensitive functions when we have all-valid codes. // In these cases, we cannot have the codes influence the nullability of the output DType. @@ -192,13 +191,9 @@ impl ArrayParentReduceRule for DictionaryScalarFnCodesPullUpRule { } } - let new_values = ScalarFnArray::try_new( - parent.scalar_fn().clone(), - new_children, - array.values().len(), - )? - .into_array() - .optimize()?; + let new_values = ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)? + .into_array() + .optimize()?; let new_dict = unsafe { DictArray::new_unchecked(array.codes().clone(), new_values) }.into_array(); diff --git a/vortex-array/src/arrays/fixed_size_list/tests/nested.rs b/vortex-array/src/arrays/fixed_size_list/tests/nested.rs index efd6f630580..8fca62c0097 100644 --- a/vortex-array/src/arrays/fixed_size_list/tests/nested.rs +++ b/vortex-array/src/arrays/fixed_size_list/tests/nested.rs @@ -270,7 +270,6 @@ fn test_fsl_of_fsl_with_nulls() { #[test] fn test_deeply_nested_fsl() { - let _len = 2; let list_size = 2; // Create a 3-level nested FSL: FSL[FSL[FSL[i32]]]. diff --git a/vortex-array/src/arrays/scalar_fn/array.rs b/vortex-array/src/arrays/scalar_fn/array.rs index 2c77be3c329..a999a25cd51 100644 --- a/vortex-array/src/arrays/scalar_fn/array.rs +++ b/vortex-array/src/arrays/scalar_fn/array.rs @@ -6,6 +6,7 @@ use std::fmt::Formatter; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use crate::ArrayRef; @@ -30,19 +31,6 @@ impl Display for ScalarFnData { } impl ScalarFnData { - /// Create a new ScalarFnArray from a scalar function and its children. - pub fn build( - scalar_fn: ScalarFnRef, - children: Vec, - len: usize, - ) -> VortexResult { - vortex_ensure!( - children.iter().all(|c| c.len() == len), - "ScalarFnArray must have children equal to the array length" - ); - Ok(Self { scalar_fn }) - } - /// Get the scalar function bound to this array. #[inline(always)] pub fn scalar_fn(&self) -> &ScalarFnRef { @@ -85,14 +73,26 @@ impl> ScalarFnArrayExt for T {} impl Array { /// Create a new ScalarFnArray from a scalar function and its children. - pub fn try_new( + pub fn try_new(scalar_fn: ScalarFnRef, children: Vec) -> VortexResult { + let len = Self::infer_len(&children)?; + Self::try_new_with_len(scalar_fn, children, len) + } + + /// Create a new ScalarFnArray from a scalar function, children, and an explicit length. + /// + /// This is needed for zero-child scalar functions and deserialization paths where there is no + /// child array to infer the length from. + pub fn try_new_with_len( scalar_fn: ScalarFnRef, children: Vec, len: usize, ) -> VortexResult { + Self::validate_children_len(&children, len)?; let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect(); let dtype = scalar_fn.return_dtype(&arg_dtypes)?; - let data = ScalarFnData::build(scalar_fn.clone(), children.clone(), len)?; + let data = ScalarFnData { + scalar_fn: scalar_fn.clone(), + }; let vtable = ScalarFn { id: scalar_fn.id() }; Ok(unsafe { Array::from_parts_unchecked( @@ -101,4 +101,19 @@ impl Array { ) }) } + + fn infer_len(children: &[ArrayRef]) -> VortexResult { + let Some(child) = children.first() else { + vortex_bail!("ScalarFnArray length cannot be inferred without children"); + }; + Ok(child.len()) + } + + fn validate_children_len(children: &[ArrayRef], len: usize) -> VortexResult<()> { + vortex_ensure!( + children.iter().all(|c| c.len() == len), + "ScalarFnArray must have children equal to the array length" + ); + Ok(()) + } } diff --git a/vortex-array/src/arrays/scalar_fn/plugin.rs b/vortex-array/src/arrays/scalar_fn/plugin.rs index c685503a370..044f79362cd 100644 --- a/vortex-array/src/arrays/scalar_fn/plugin.rs +++ b/vortex-array/src/arrays/scalar_fn/plugin.rs @@ -82,7 +82,7 @@ impl ArrayPlugin for ScalarFnArrayPlugi let parts = ::deserialize( &self.0, dtype, len, metadata, children, session, )?; - Ok(ScalarFnArray::try_new( + Ok(ScalarFnArray::try_new_with_len( TypedScalarFnInstance::new(self.0.clone(), parts.options).erased(), parts.children, len, diff --git a/vortex-array/src/arrays/scalar_fn/rules.rs b/vortex-array/src/arrays/scalar_fn/rules.rs index 1e9563cf9de..abc8e18e762 100644 --- a/vortex-array/src/arrays/scalar_fn/rules.rs +++ b/vortex-array/src/arrays/scalar_fn/rules.rs @@ -84,7 +84,8 @@ impl ArrayParentReduceRule for ScalarFnSliceReduceRule { .collect::>()?; Ok(Some( - ScalarFnArray::try_new(array.scalar_fn().clone(), children, range.len())?.into_array(), + ScalarFnArray::try_new_with_len(array.scalar_fn().clone(), children, range.len())? + .into_array(), )) } } @@ -142,7 +143,7 @@ impl ReduceCtx for ArrayReduceCtx { children: &[ReduceNodeRef], ) -> VortexResult { Ok(Arc::new( - ScalarFnArray::try_new( + ScalarFnArray::try_new_with_len( scalar_fn, children .iter() @@ -191,8 +192,7 @@ impl ArrayParentReduceRule for ScalarFnUnaryFilterPushDownRule { .try_collect()?; let new_array = - ScalarFnArray::try_new(child.scalar_fn().clone(), new_children, parent.len())? - .into_array(); + ScalarFnArray::try_new(child.scalar_fn().clone(), new_children)?.into_array(); return Ok(Some(new_array)); } diff --git a/vortex-array/src/arrays/scalar_fn/vtable/operations.rs b/vortex-array/src/arrays/scalar_fn/vtable/operations.rs index 1f710af3cdf..67de19d936f 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/operations.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/operations.rs @@ -65,9 +65,12 @@ mod tests { use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; use crate::arrays::ScalarFnArray; + use crate::arrays::scalar_fn::ScalarFnArrayExt; use crate::assert_arrays_eq; + use crate::scalar::Scalar; use crate::scalar_fn::TypedScalarFnInstance; use crate::scalar_fn::fns::binary::Binary; + use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::operators::Operator; use crate::validity::Validity; @@ -77,7 +80,9 @@ mod tests { let rhs = buffer![10i32, 20, 30].into_array(); let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased(); - let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?; + let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; + + assert_eq!(scalar_fn_array.len(), 3); let result = scalar_fn_array .into_array() @@ -89,13 +94,47 @@ mod tests { Ok(()) } + #[test] + fn test_scalar_fn_inferred_len_rejects_mismatched_children() { + let lhs = buffer![1i32, 2, 3].into_array(); + let rhs = buffer![10i32, 20].into_array(); + + let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased(); + let err = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs]) + .expect_err("ScalarFnArray::try_new must reject mismatched child lengths"); + + assert!( + err.to_string() + .contains("ScalarFnArray must have children equal to the array length") + ); + } + + #[test] + fn test_scalar_fn_without_children_requires_explicit_len() -> VortexResult<()> { + let scalar_fn = TypedScalarFnInstance::new(Literal, Scalar::from(1i32)).erased(); + + let Err(err) = ScalarFnArray::try_new(scalar_fn.clone(), vec![]) else { + panic!("ScalarFnArray::try_new should reject zero children"); + }; + assert!( + err.to_string() + .contains("ScalarFnArray length cannot be inferred without children") + ); + + let scalar_fn_array = ScalarFnArray::try_new_with_len(scalar_fn, vec![], 3)?; + assert_eq!(scalar_fn_array.len(), 3); + assert_eq!(scalar_fn_array.child_count(), 0); + + Ok(()) + } + #[test] fn test_scalar_fn_mul() -> VortexResult<()> { let lhs = buffer![2i32, 3, 4].into_array(); let rhs = buffer![5i32, 6, 7].into_array(); let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Mul).erased(); - let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?; + let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let result = scalar_fn_array .into_array() @@ -117,7 +156,7 @@ mod tests { .into_array(); let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased(); - let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?; + let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let result = scalar_fn_array .into_array() @@ -139,7 +178,7 @@ mod tests { let rhs = buffer![2i32, 5, 1].into_array(); let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Eq).erased(); - let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?; + let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let result = scalar_fn_array .into_array() diff --git a/vortex-array/src/expression.rs b/vortex-array/src/expression.rs index 780bfcf9db1..d6b8739abf6 100644 --- a/vortex-array/src/expression.rs +++ b/vortex-array/src/expression.rs @@ -35,7 +35,8 @@ impl ArrayRef { // And wrap the scalar function up in an array. let array = - ScalarFnArray::try_new(expr.scalar_fn().clone(), children, self.len())?.into_array(); + ScalarFnArray::try_new_with_len(expr.scalar_fn().clone(), children, self.len())? + .into_array(); // Optimize the resulting array's root. array.optimize() diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index e656ba18822..4f9c9153e11 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -91,7 +91,6 @@ pub fn turboquant_encode( let l2_denorm = normalize_as_l2_denorm(input, ctx)?; let normalized = l2_denorm.child_at(0).clone(); let norms = l2_denorm.child_at(1).clone(); - let num_rows = l2_denorm.len(); let normalized_ext = normalized .as_opt::() @@ -102,7 +101,7 @@ pub fn turboquant_encode( // SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally // bypass the strict normalized-row validation when reattaching the stored norms. - Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) + Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms) }?.into_array()) } /// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a @@ -164,9 +163,7 @@ pub unsafe fn turboquant_encode_unchecked( dimensions: dimension, element_ptype, }; - return Ok( - SorfTransform::try_new_array(&sorf_options, empty_padded_vector, 0)?.into_array(), - ); + return Ok(SorfTransform::try_new_array(&sorf_options, empty_padded_vector)?.into_array()); } let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?; @@ -180,7 +177,7 @@ pub unsafe fn turboquant_encode_unchecked( dimensions: dimension, element_ptype, }; - Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array()) + Ok(SorfTransform::try_new_array(&sorf_options, padded_vector)?.into_array()) } /// Shared intermediate results from the quantization loop. diff --git a/vortex-tensor/src/encodings/turboquant/tests/compute.rs b/vortex-tensor/src/encodings/turboquant/tests/compute.rs index 4d670695eaf..ba7ab5fb568 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/compute.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/compute.rs @@ -17,19 +17,17 @@ use crate::scalar_fns::l2_norm::L2Norm; fn execute_l2_norm( input: ArrayRef, - len: usize, ctx: &mut vortex_array::ExecutionCtx, ) -> VortexResult { - L2Norm::try_new_array(input, len)?.into_array().execute(ctx) + L2Norm::try_new_array(input)?.into_array().execute(ctx) } fn execute_cosine_similarity( lhs: ArrayRef, rhs: ArrayRef, - len: usize, ctx: &mut vortex_array::ExecutionCtx, ) -> VortexResult { - CosineSimilarity::try_new_array(lhs, rhs, len)? + CosineSimilarity::try_new_array(lhs, rhs)? .into_array() .execute(ctx) } @@ -133,7 +131,7 @@ fn l2_norm_readthrough() -> VortexResult<()> { } // Also verify L2Norm readthrough shortcut works. - let norms = execute_l2_norm(encoded, 10, &mut ctx)?; + let norms = execute_l2_norm(encoded, &mut ctx)?; assert_eq!(norms.as_slice::(), stored_norms); assert_eq!(norms.len(), 10); Ok(()) @@ -154,14 +152,14 @@ fn l2_norm_readthrough_is_authoritative_for_lossy_storage() -> VortexResult<()> let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); let stored_norms: PrimitiveArray = norms_child.execute(&mut ctx)?; - let encoded_norms = execute_l2_norm(encoded.clone(), num_rows, &mut ctx)?; + let encoded_norms = execute_l2_norm(encoded.clone(), &mut ctx)?; assert_eq!( encoded_norms.as_slice::(), stored_norms.as_slice::() ); let decoded = encoded.execute::(&mut ctx)?.into_array(); - let decoded_norms = execute_l2_norm(decoded, num_rows, &mut ctx)?; + let decoded_norms = execute_l2_norm(decoded, &mut ctx)?; let max_gap = stored_norms .as_slice::() .iter() @@ -189,10 +187,9 @@ fn cosine_similarity_readthrough_is_authoritative_for_lossy_storage() -> VortexR let mut ctx = SESSION.create_execution_ctx(); let encoded = turboquant_encode(ext, &config, &mut ctx)?; - let encoded_cos = - execute_cosine_similarity(encoded.clone(), encoded.clone(), num_rows, &mut ctx)?; + let encoded_cos = execute_cosine_similarity(encoded.clone(), encoded.clone(), &mut ctx)?; let decoded = encoded.execute::(&mut ctx)?.into_array(); - let decoded_cos = execute_cosine_similarity(decoded.clone(), decoded, num_rows, &mut ctx)?; + let decoded_cos = execute_cosine_similarity(decoded.clone(), decoded, &mut ctx)?; let decoded_values = decoded_cos.as_slice::(); assert!( diff --git a/vortex-tensor/src/encodings/turboquant/tests/nullable.rs b/vortex-tensor/src/encodings/turboquant/tests/nullable.rs index 92eb0e7b152..c7d28c4f14f 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/nullable.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/nullable.rs @@ -120,7 +120,7 @@ fn nullable_l2_norm_readthrough() -> VortexResult<()> { let mut ctx = SESSION.create_execution_ctx(); let encoded = turboquant_encode(ext, &config, &mut ctx)?; - let norm_sfn = L2Norm::try_new_array(encoded, 5)?; + let norm_sfn = L2Norm::try_new_array(encoded)?; let norms: PrimitiveArray = norm_sfn.into_array().execute(&mut ctx)?; let orig_prim = fsl.elements().clone().execute::(&mut ctx)?; diff --git a/vortex-tensor/src/encodings/turboquant/tests/structural.rs b/vortex-tensor/src/encodings/turboquant/tests/structural.rs index bc9a5e207f1..4f81df6d416 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/structural.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/structural.rs @@ -303,8 +303,7 @@ fn sorf_transform_roundtrip_isolation() -> VortexResult<()> { dimensions: dim as u32, element_ptype: vortex_array::dtype::PType::F32, }; - let sorf_array = - SorfTransform::try_new_array(&sorf_options, padded_vector.into_array(), num_rows)?; + let sorf_array = SorfTransform::try_new_array(&sorf_options, padded_vector.into_array())?; let mut ctx = SESSION.create_execution_ctx(); let result: ExtensionArray = sorf_array.into_array().execute(&mut ctx)?; diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 0b3176915fd..c77c421e493 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -71,8 +71,8 @@ impl CosineSimilarity { /// /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype /// mismatches). - pub fn try_new_array(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult { - ScalarFnArray::try_new(CosineSimilarity::new().erased(), vec![lhs, rhs], len) + pub fn try_new_array(lhs: ArrayRef, rhs: ArrayRef) -> VortexResult { + ScalarFnArray::try_new(CosineSimilarity::new().erased(), vec![lhs, rhs]) } } @@ -141,9 +141,9 @@ impl ScalarFnVTable for CosineSimilarity { let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; // Compute inner product and norms as columnar operations, and propagate the options. - let norm_lhs_arr = L2Norm::try_new_array(lhs_ref.clone(), len)?; - let norm_rhs_arr = L2Norm::try_new_array(rhs_ref.clone(), len)?; - let dot_arr = InnerProduct::try_new_array(lhs_ref, rhs_ref, len)?; + let norm_lhs_arr = L2Norm::try_new_array(lhs_ref.clone())?; + let norm_rhs_arr = L2Norm::try_new_array(rhs_ref.clone())?; + let dot_arr = InnerProduct::try_new_array(lhs_ref, rhs_ref)?; // Execute to get the inner product and norms of the arrays. We only fully decompress // because we need to perform special logic (guard against 0) during division. @@ -239,7 +239,7 @@ impl CosineSimilarity { // `L2Denorm` makes the normalized children authoritative, so their dot product is the // cosine similarity even for lossy storage wrappers, except that a zero stored norm still // represents a zero vector. - let dot: PrimitiveArray = InnerProduct::try_new_array(normalized_l, normalized_r, len)? + let dot: PrimitiveArray = InnerProduct::try_new_array(normalized_l, normalized_r)? .into_array() .execute(ctx)?; let norms_l: PrimitiveArray = norms_l.execute(ctx)?; @@ -279,12 +279,12 @@ impl CosineSimilarity { let (normalized, denorm_norms) = extract_l2_denorm_children(denorm_ref); - let dot_arr = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)?; + let dot_arr = InnerProduct::try_new_array(normalized, plain_ref.clone())?; let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?; let denorm_norms: PrimitiveArray = denorm_norms.execute(ctx)?; - let norm_arr = L2Norm::try_new_array(plain_ref.clone(), len)?; + let norm_arr = L2Norm::try_new_array(plain_ref.clone())?; let plain_norm: PrimitiveArray = norm_arr.into_array().execute(ctx)?; // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation. @@ -335,9 +335,9 @@ mod tests { use crate::utils::test_helpers::vector_array; /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. - fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { + fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef) -> VortexResult> { let scalar_fn = CosineSimilarity::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; Ok(prim.as_slice::().to_vec()) @@ -361,7 +361,7 @@ mod tests { )?; // Row 0: identical -> 1.0, row 1: orthogonal -> 0.0. - assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0, 0.0]); Ok(()) } @@ -381,7 +381,7 @@ mod tests { ) -> VortexResult<()> { let lhs = tensor_array(shape, lhs_elems)?; let rhs = tensor_array(shape, rhs_elems)?; - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, expected); + assert_close(&eval_cosine_similarity(lhs, rhs)?, expected); Ok(()) } @@ -400,7 +400,7 @@ mod tests { fn self_similarity(#[case] shape: &[usize], #[case] elements: &[f64]) -> VortexResult<()> { let lhs = tensor_array(shape, elements)?; let rhs = tensor_array(shape, elements)?; - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0]); Ok(()) } @@ -411,7 +411,7 @@ mod tests { let rhs = tensor_array(&[], &[5.0, -3.0])?; // Same sign -> 1.0, opposite sign -> -1.0. - assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, -1.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0, -1.0]); Ok(()) } @@ -431,7 +431,7 @@ mod tests { let rhs = lhs.clone(); assert_close( - &eval_cosine_similarity(lhs, rhs, 5)?, + &eval_cosine_similarity(lhs, rhs)?, &[1.0, 1.0, 1.0, 1.0, 1.0], ); Ok(()) @@ -451,10 +451,7 @@ mod tests { )?; let query = constant_tensor_array(&[3], &[1.0, 0.0, 0.0], 4)?; - assert_close( - &eval_cosine_similarity(data, query, 4)?, - &[1.0, 0.0, 0.0, 1.0], - ); + assert_close(&eval_cosine_similarity(data, query)?, &[1.0, 0.0, 0.0, 1.0]); Ok(()) } @@ -476,7 +473,7 @@ mod tests { )?; // Row 0: identical -> 1.0, row 1: orthogonal -> 0.0. - assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0, 0.0]); Ok(()) } @@ -493,10 +490,7 @@ mod tests { )?; let query = Vector::constant_array(&[1.0, 0.0, 0.0], 4)?; - assert_close( - &eval_cosine_similarity(data, query, 4)?, - &[1.0, 0.0, 0.0, 1.0], - ); + assert_close(&eval_cosine_similarity(data, query)?, &[1.0, 0.0, 0.0, 1.0]); Ok(()) } @@ -508,7 +502,7 @@ mod tests { let rhs = MaskedArray::try_new(rhs, Validity::from_iter([true, false]))?.into_array(); let scalar_fn = CosineSimilarity::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; @@ -528,7 +522,7 @@ mod tests { let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Self-similarity should always be 1.0. - assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 1.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0, 1.0]); Ok(()) } @@ -540,7 +534,7 @@ mod tests { let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0], &mut ctx)?; let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0], &mut ctx)?; - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[0.0]); Ok(()) } @@ -552,7 +546,7 @@ mod tests { let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0, row 1: dot([0,0], [1,0]) = 0.0. - assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0, 0.0]); Ok(()) } @@ -565,7 +559,7 @@ mod tests { let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; let rhs = tensor_array(&[2], &[3.0, 4.0])?; - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0]); Ok(()) } @@ -577,7 +571,7 @@ mod tests { let lhs = tensor_array(&[2], &[1.0, 0.0])?; let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.6]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[0.6]); Ok(()) } @@ -589,10 +583,10 @@ mod tests { let normalized_r = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?; let norms_r = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); - let rhs = L2Denorm::try_new_array(normalized_r, norms_r, 2, &mut ctx)?.into_array(); + let rhs = L2Denorm::try_new_array(normalized_r, norms_r, &mut ctx)?.into_array(); let scalar_fn = CosineSimilarity::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; assert!(prim.is_valid(0, &mut ctx)?); @@ -611,16 +605,16 @@ mod tests { let norms_l = PrimitiveArray::from_iter([0.0f64]).into_array(); // SAFETY: This is a focused test that intentionally violates the unit-norm invariant by // pairing a nonzero normalized row with a stored norm of `0.0`, mimicking lossy storage. - let lhs = unsafe { L2Denorm::new_array_unchecked(normalized_l, norms_l, 1)? }.into_array(); + let lhs = unsafe { L2Denorm::new_array_unchecked(normalized_l, norms_l)? }.into_array(); let normalized_r = tensor_array(&[2], &[0.6, 0.8])?; let norms_r = PrimitiveArray::from_iter([0.0f64]).into_array(); // SAFETY: Same as above for the rhs operand. - let rhs = unsafe { L2Denorm::new_array_unchecked(normalized_r, norms_r, 1)? }.into_array(); + let rhs = unsafe { L2Denorm::new_array_unchecked(normalized_r, norms_r)? }.into_array(); // `dot(normalized_l, normalized_r) = 1.0`, but the authoritative stored norms are both // `0.0`, so cosine similarity must be `0.0`. - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[0.0]); Ok(()) } @@ -634,19 +628,19 @@ mod tests { let norms = PrimitiveArray::from_iter([0.0f64]).into_array(); // SAFETY: This is a focused test that intentionally pairs a nonzero normalized row with a // stored norm of `0.0`, mimicking lossy storage where the stored norm is authoritative. - let denorm = unsafe { L2Denorm::new_array_unchecked(normalized, norms, 1)? }.into_array(); + let denorm = unsafe { L2Denorm::new_array_unchecked(normalized, norms)? }.into_array(); let plain = tensor_array(&[2], &[1.0, 0.0])?; // Denorm on the lhs: `One { denorm: lhs, plain: rhs }`. assert_close( - &eval_cosine_similarity(denorm.clone(), plain.clone(), 1)?, + &eval_cosine_similarity(denorm.clone(), plain.clone())?, &[0.0], ); // Denorm on the rhs: `One { denorm: rhs, plain: lhs }`. The same zero-norm guard must // fire regardless of operand order. - assert_close(&eval_cosine_similarity(plain, denorm, 1)?, &[0.0]); + assert_close(&eval_cosine_similarity(plain, denorm)?, &[0.0]); Ok(()) } @@ -665,7 +659,7 @@ mod tests { ], )?; assert_close( - &eval_cosine_similarity(lhs, rhs, 4)?, + &eval_cosine_similarity(lhs, rhs)?, &[1.0 / 3.0, 1.0, 2.0 / 3.0, 8.0 / 9.0], ); Ok(()) @@ -685,7 +679,7 @@ mod tests { )?; let rhs = constant_tensor_array(&[3], &[1.0, 2.0, 2.0], 4)?; assert_close( - &eval_cosine_similarity(lhs, rhs, 4)?, + &eval_cosine_similarity(lhs, rhs)?, &[1.0 / 3.0, 1.0, 2.0 / 3.0, 8.0 / 9.0], ); Ok(()) @@ -698,7 +692,7 @@ mod tests { let rhs = constant_tensor_array(&[3], &[1.0, 1.0, 0.0], 3)?; let expected = 1.0 / 2.0_f64.sqrt(); assert_close( - &eval_cosine_similarity(lhs, rhs, 3)?, + &eval_cosine_similarity(lhs, rhs)?, &[expected, expected, expected], ); Ok(()) @@ -717,7 +711,7 @@ mod tests { 7.0, 8.0, 9.0, // ], )?; - assert_close(&eval_cosine_similarity(lhs, rhs, 3)?, &[0.0, 0.0, 0.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[0.0, 0.0, 0.0]); Ok(()) } @@ -728,7 +722,7 @@ mod tests { // L2Denorm fast path's inner product yields 1. let lhs = constant_tensor_array(&[3], &[3.0, 4.0, 0.0], 5)?; let rhs = constant_tensor_array(&[3], &[3.0, 4.0, 0.0], 5)?; - assert_close(&eval_cosine_similarity(lhs, rhs, 5)?, &[1.0; 5]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0; 5]); Ok(()) } @@ -746,21 +740,17 @@ mod tests { ], )?; assert_close( - &eval_cosine_similarity(lhs, rhs, 4)?, + &eval_cosine_similarity(lhs, rhs)?, &[1.0 / 3.0, 1.0, 2.0 / 3.0, 8.0 / 9.0], ); Ok(()) } #[rstest] - #[case::vector(cosine_vector_lhs(), cosine_vector_rhs(), 2)] - #[case::fixed_shape_tensor(cosine_tensor_lhs(), cosine_tensor_rhs(), 2)] - fn serde_round_trip( - #[case] lhs: ArrayRef, - #[case] rhs: ArrayRef, - #[case] len: usize, - ) -> VortexResult<()> { - let original = CosineSimilarity::try_new_array(lhs.clone(), rhs.clone(), len)?.into_array(); + #[case::vector(cosine_vector_lhs(), cosine_vector_rhs())] + #[case::fixed_shape_tensor(cosine_tensor_lhs(), cosine_tensor_rhs())] + fn serde_round_trip(#[case] lhs: ArrayRef, #[case] rhs: ArrayRef) -> VortexResult<()> { + let original = CosineSimilarity::try_new_array(lhs.clone(), rhs.clone())?.into_array(); let plugin = ScalarFnArrayPlugin::new(CosineSimilarity); let metadata = plugin diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index b3ba7a7b557..3c6b7bf610a 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -82,8 +82,8 @@ impl InnerProduct { /// /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype /// mismatches). - pub fn try_new_array(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult { - ScalarFnArray::try_new(InnerProduct::new().erased(), vec![lhs, rhs], len) + pub fn try_new_array(lhs: ArrayRef, rhs: ArrayRef) -> VortexResult { + ScalarFnArray::try_new(InnerProduct::new().erased(), vec![lhs, rhs]) } } @@ -250,7 +250,7 @@ impl InnerProduct { let norms_l: PrimitiveArray = norms_l.execute(ctx)?; let norms_r: PrimitiveArray = norms_r.execute(ctx)?; - let dot: PrimitiveArray = InnerProduct::try_new_array(normalized_l, normalized_r, len)? + let dot: PrimitiveArray = InnerProduct::try_new_array(normalized_l, normalized_r)? .into_array() .execute(ctx)?; @@ -280,7 +280,7 @@ impl InnerProduct { let (normalized, norms) = extract_l2_denorm_children(denorm_ref); let denorm_norms: PrimitiveArray = norms.execute(ctx)?; - let dot: PrimitiveArray = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)? + let dot: PrimitiveArray = InnerProduct::try_new_array(normalized, plain_ref.clone())? .into_array() .execute(ctx)?; @@ -371,7 +371,7 @@ impl InnerProduct { // the rewritten tree if the sorf child is `Vector[FSL(Dict)]`. Termination is // guaranteed because the rewrite strictly removes a `SorfTransform` scalar-fn node // from the tree and SORFs cannot be nested. - let rewritten = InnerProduct::try_new_array(sorf_child, new_constant, len)? + let rewritten = InnerProduct::try_new_array(sorf_child, new_constant)? .into_array() .execute(ctx)?; Ok(Some(rewritten)) @@ -572,9 +572,9 @@ mod tests { use crate::utils::test_helpers::vector_array; /// Evaluates inner product between two tensor arrays and returns the result as `Vec`. - fn eval_inner_product(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { + fn eval_inner_product(lhs: ArrayRef, rhs: ArrayRef) -> VortexResult> { let scalar_fn = InnerProduct::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; Ok(prim.as_slice::().to_vec()) @@ -598,7 +598,7 @@ mod tests { ) -> VortexResult<()> { let lhs = tensor_array(shape, lhs_elems)?; let rhs = tensor_array(shape, rhs_elems)?; - assert_close(&eval_inner_product(lhs, rhs, 1)?, expected); + assert_close(&eval_inner_product(lhs, rhs)?, expected); Ok(()) } @@ -620,7 +620,7 @@ mod tests { 2.0, 2.0, 2.0, // tensor 2: dot = 6 ], )?; - assert_close(&eval_inner_product(lhs, rhs, 3)?, &[0.0, 25.0, 6.0]); + assert_close(&eval_inner_product(lhs, rhs)?, &[0.0, 25.0, 6.0]); Ok(()) } @@ -640,7 +640,7 @@ mod tests { 0.0, 1.0, // vector 1: dot = 0 ], )?; - assert_close(&eval_inner_product(lhs, rhs, 2)?, &[25.0, 0.0]); + assert_close(&eval_inner_product(lhs, rhs)?, &[25.0, 0.0]); Ok(()) } @@ -652,7 +652,7 @@ mod tests { let lhs = MaskedArray::try_new(lhs, Validity::from_iter([true, false, true]))?.into_array(); let scalar_fn = InnerProduct::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; @@ -669,7 +669,7 @@ mod tests { fn rejects_non_extension_dtype() { let lhs = PrimitiveArray::from_iter([1.0_f64, 2.0]).into_array(); let rhs = PrimitiveArray::from_iter([3.0_f64, 4.0]).into_array(); - let result = InnerProduct::try_new_array(lhs, rhs, 2); + let result = InnerProduct::try_new_array(lhs, rhs); assert!(result.is_err()); } @@ -677,7 +677,7 @@ mod tests { fn rejects_mismatched_dtypes() -> VortexResult<()> { let lhs = tensor_array(&[2], &[1.0_f64, 2.0])?; let rhs = vector_array(2, &[3.0_f64, 4.0])?; - let result = InnerProduct::try_new_array(lhs, rhs, 1); + let result = InnerProduct::try_new_array(lhs, rhs); assert!(result.is_err()); Ok(()) } @@ -692,7 +692,7 @@ mod tests { let rhs = l2_denorm_array(&[2], &[1.0, 0.0], &[1.0], &mut ctx)?; // Expected: 5.0 * 1.0 * dot([0.6, 0.8], [1.0, 0.0]) = 5.0 * 0.6 = 3.0. - assert_close(&eval_inner_product(lhs, rhs, 1)?, &[3.0]); + assert_close(&eval_inner_product(lhs, rhs)?, &[3.0]); Ok(()) } @@ -704,7 +704,7 @@ mod tests { let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0], &mut ctx)?; - assert_close(&eval_inner_product(lhs, rhs, 2)?, &[25.0, 0.0]); + assert_close(&eval_inner_product(lhs, rhs)?, &[25.0, 0.0]); Ok(()) } @@ -717,7 +717,7 @@ mod tests { let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; let rhs = tensor_array(&[2], &[1.0, 2.0])?; - assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); + assert_close(&eval_inner_product(lhs, rhs)?, &[11.0]); Ok(()) } @@ -730,7 +730,7 @@ mod tests { let lhs = tensor_array(&[2], &[1.0, 2.0])?; let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; - assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); + assert_close(&eval_inner_product(lhs, rhs)?, &[11.0]); Ok(()) } @@ -741,11 +741,11 @@ mod tests { let norms_l = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let lhs = L2Denorm::try_new_array(normalized_l, norms_l, 2, &mut ctx)?.into_array(); + let lhs = L2Denorm::try_new_array(normalized_l, norms_l, &mut ctx)?.into_array(); let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; let scalar_fn = InnerProduct::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; // Row 0: 5.0 * 5.0 * dot([0.6, 0.8], [0.6, 0.8]) = 25.0, row 1: null. @@ -756,14 +756,10 @@ mod tests { } #[rstest] - #[case::vector(inner_product_vector_lhs(), inner_product_vector_rhs(), 2)] - #[case::fixed_shape_tensor(inner_product_tensor_lhs(), inner_product_tensor_rhs(), 2)] - fn serde_round_trip( - #[case] lhs: ArrayRef, - #[case] rhs: ArrayRef, - #[case] len: usize, - ) -> VortexResult<()> { - let original = InnerProduct::try_new_array(lhs.clone(), rhs.clone(), len)?.into_array(); + #[case::vector(inner_product_vector_lhs(), inner_product_vector_rhs())] + #[case::fixed_shape_tensor(inner_product_tensor_lhs(), inner_product_tensor_rhs())] + fn serde_round_trip(#[case] lhs: ArrayRef, #[case] rhs: ArrayRef) -> VortexResult<()> { + let original = InnerProduct::try_new_array(lhs.clone(), rhs.clone())?.into_array(); let plugin = ScalarFnArrayPlugin::new(InnerProduct); let metadata = plugin @@ -857,9 +853,9 @@ mod tests { } /// Execute an inner product and return the flat `f32` results. - fn eval_ip_f32(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { + fn eval_ip_f32(lhs: ArrayRef, rhs: ArrayRef) -> VortexResult> { let scalar_fn = InnerProduct::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; Ok(prim.as_slice::().to_vec()) @@ -900,8 +896,7 @@ mod tests { dimensions: dim, element_ptype: PType::F32, }; - let sorf = - SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array(); + let sorf = SorfTransform::try_new_array(&sorf_options, padded_vector)?.into_array(); Ok((sorf, codes, values, padded_dim)) } @@ -996,7 +991,7 @@ mod tests { }) .collect(); - let actual = eval_ip_f32(sorf_lhs, const_rhs, num_rows)?; + let actual = eval_ip_f32(sorf_lhs, const_rhs)?; assert_close_f32(&actual, &expected, 1e-3); Ok(()) } @@ -1033,7 +1028,7 @@ mod tests { }) .collect(); - let actual = eval_ip_f32(const_lhs, sorf, num_rows)?; + let actual = eval_ip_f32(const_lhs, sorf)?; assert_close_f32(&actual, &expected, 1e-3); Ok(()) } @@ -1071,7 +1066,7 @@ mod tests { }) .collect(); - let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; + let actual = eval_ip_f32(sorf, const_rhs)?; assert_close_f32(&actual, &expected, 1e-3); Ok(()) } @@ -1090,7 +1085,7 @@ mod tests { let query_elems: Vec = vec![0.0; dim as usize]; let const_rhs = Vector::constant_array(&query_elems, num_rows)?; - let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; + let actual = eval_ip_f32(sorf, const_rhs)?; assert_eq!(actual.len(), 0); Ok(()) } @@ -1124,7 +1119,7 @@ mod tests { }) .collect(); - let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; + let actual = eval_ip_f32(dict_lhs, const_rhs)?; assert_close_f32(&actual, &expected, 1e-5); Ok(()) } @@ -1154,7 +1149,7 @@ mod tests { }) .collect(); - let actual = eval_ip_f32(const_lhs, dict_rhs, num_rows)?; + let actual = eval_ip_f32(const_lhs, dict_rhs)?; assert_close_f32(&actual, &expected, 1e-5); Ok(()) } @@ -1203,7 +1198,7 @@ mod tests { }) .collect(); - let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; + let actual = eval_ip_f32(dict_lhs, const_rhs)?; assert_close_f32(&actual, &expected, 1e-5); Ok(()) } @@ -1231,7 +1226,7 @@ mod tests { }) .collect(); - let actual = eval_ip_f32(plain_lhs, const_rhs, num_rows)?; + let actual = eval_ip_f32(plain_lhs, const_rhs)?; assert_close_f32(&actual, &expected, 1e-5); Ok(()) } @@ -1249,7 +1244,7 @@ mod tests { let query: Vec = vec![0.0; 4]; let const_rhs = Vector::constant_array(&query, num_rows)?; - let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; + let actual = eval_ip_f32(dict_lhs, const_rhs)?; assert_eq!(actual.len(), 0); Ok(()) } @@ -1288,7 +1283,7 @@ mod tests { }) .collect(); - let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; + let actual = eval_ip_f32(sorf, const_rhs)?; assert_close_f32(&actual, &expected, 1e-3); Ok(()) } @@ -1353,7 +1348,7 @@ mod tests { }) .collect(); - let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; + let actual = eval_ip_f32(dict_lhs, const_rhs)?; assert_close_f32(&actual, &expected, 1e-4); Ok(()) } @@ -1402,7 +1397,7 @@ mod tests { // Loose tolerance: the sorf transform works in f32 with a k-round butterfly, so // the rewrite path and the decoded path accumulate slightly different rounding // even though the math is equivalent in exact arithmetic. - let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; + let actual = eval_ip_f32(sorf, const_rhs)?; assert_close_f32(&actual, &expected, 1e-2); Ok(()) } @@ -1445,7 +1440,7 @@ mod tests { // Tight tolerance here because no SorfTransform rotation is involved — the // arithmetic should agree bit-for-bit up to float reassociation. - let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; + let actual = eval_ip_f32(dict_lhs, const_rhs)?; assert_close_f32(&actual, &expected, 1e-4); Ok(()) } @@ -1476,8 +1471,7 @@ mod tests { dimensions: dim, element_ptype: PType::F32, }; - let sorf = - SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array(); + let sorf = SorfTransform::try_new_array(&sorf_options, padded_vector)?.into_array(); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); let const_rhs = Vector::constant_array(&query, num_rows)?; @@ -1495,7 +1489,7 @@ mod tests { .map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query)) .collect(); - let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; + let actual = eval_ip_f32(sorf, const_rhs)?; assert_close_f32(&actual, &expected, 1e-2); // Also verify the max relative error is small. The SORF rotation does not @@ -1545,7 +1539,7 @@ mod tests { .map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query)) .collect(); - let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; + let actual = eval_ip_f32(sorf, const_rhs)?; assert_close_f32(&actual, &expected, 1e-2); Ok(()) } @@ -1579,7 +1573,7 @@ mod tests { .map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query)) .collect(); - let actual = eval_ip_f32(const_lhs, sorf, num_rows)?; + let actual = eval_ip_f32(const_lhs, sorf)?; assert_close_f32(&actual, &expected, 1e-2); Ok(()) } diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index a814af33f33..ebcf7be40b7 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -111,13 +111,12 @@ impl L2Denorm { pub fn try_new_array( normalized: ArrayRef, norms: ArrayRef, - len: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { validate_l2_normalized_rows_against_norms(&normalized, Some(&norms), ctx)?; // SAFETY: We just validated that it is normalized. - unsafe { Self::new_array_unchecked(normalized, norms, len) } + unsafe { Self::new_array_unchecked(normalized, norms) } } /// Constructs an [`L2Denorm`] array without validating that the `normalized` child is actually @@ -140,9 +139,8 @@ impl L2Denorm { pub unsafe fn new_array_unchecked( normalized: ArrayRef, norms: ArrayRef, - len: usize, ) -> VortexResult { - ScalarFnArray::try_new(L2Denorm::new().erased(), vec![normalized, norms], len) + ScalarFnArray::try_new(L2Denorm::new().erased(), vec![normalized, norms]) } } @@ -432,7 +430,7 @@ pub fn normalize_as_l2_denorm( } // Calculate the norms of the vectors. - let norms_sfn = L2Norm::try_new_array(input.clone(), row_count)?; + let norms_sfn = L2Norm::try_new_array(input.clone())?; let norms_array: ArrayRef = norms_sfn.into_array().execute(ctx)?; let primitive_norms: PrimitiveArray = norms_array.clone().execute(ctx)?; let norms_validity = primitive_norms.validity()?; @@ -487,7 +485,7 @@ pub fn normalize_as_l2_denorm( // construction. // - Null rows are zeroed out above to avoid propagating arbitrary physical storage values into // downstream lossy encodings. - unsafe { L2Denorm::new_array_unchecked(normalized, norms_array, row_count) } + unsafe { L2Denorm::new_array_unchecked(normalized, norms_array) } } /// Attempts to build an [`L2Denorm`] whose two children are both [`ConstantArray`]s by eagerly @@ -570,7 +568,7 @@ pub(crate) fn try_build_constant_l2_denorm( // point tolerance) or all zeros when `||v|| == 0`. Stored norms are non-negative by // construction (`sqrt`). These are exactly the invariants required by // [`L2Denorm::new_array_unchecked`]. - let wrapped = unsafe { L2Denorm::new_array_unchecked(normalized_ext, norms_array, len)? }; + let wrapped = unsafe { L2Denorm::new_array_unchecked(normalized_ext, norms_array)? }; Ok(Some(wrapped)) } @@ -771,9 +769,9 @@ mod tests { use crate::utils::test_helpers::vector_array; /// Evaluates L2 denorm on a tensor/vector array and returns the executed array. - fn eval_l2_denorm(normalized: ArrayRef, norms: ArrayRef, len: usize) -> VortexResult { + fn eval_l2_denorm(normalized: ArrayRef, norms: ArrayRef) -> VortexResult { let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?; + let result = L2Denorm::try_new_array(normalized, norms, &mut ctx)?; result.into_array().execute(&mut ctx) } @@ -813,7 +811,7 @@ mod tests { fn l2_denorm_vectors() -> VortexResult<()> { let lhs = vector_array(3, &[0.6, 0.8, 0.0, 0.0, 0.0, 0.0])?; let rhs = PrimitiveArray::from_iter([5.0f64, 0.0]).into_array(); - let actual = eval_l2_denorm(lhs, rhs, 2)?; + let actual = eval_l2_denorm(lhs, rhs)?; let expected = vector_array(3, &[3.0, 4.0, 0.0, 0.0, 0.0, 0.0])?; assert_tensor_arrays_eq(actual, expected)?; @@ -824,7 +822,7 @@ mod tests { fn l2_denorm_fixed_shape_tensors() -> VortexResult<()> { let lhs = tensor_array(&[2, 2], &[0.5, 0.5, 0.5, 0.5, 1.0, 0.0, 0.0, 0.0])?; let rhs = PrimitiveArray::from_iter([4.0f64, 2.0]).into_array(); - let actual = eval_l2_denorm(lhs, rhs, 2)?; + let actual = eval_l2_denorm(lhs, rhs)?; let expected = tensor_array(&[2, 2], &[2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.0, 0.0])?; assert_tensor_arrays_eq(actual, expected)?; @@ -838,7 +836,7 @@ mod tests { let rhs = PrimitiveArray::from_option_iter([Some(5.0f64), Some(2.0), None]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let actual: ExtensionArray = eval_l2_denorm(lhs, rhs, 3)?.execute(&mut ctx)?; + let actual: ExtensionArray = eval_l2_denorm(lhs, rhs)?.execute(&mut ctx)?; let storage: FixedSizeListArray = actual.storage_array().clone().execute(&mut ctx)?; let elements: PrimitiveArray = storage.elements().clone().execute(&mut ctx)?; @@ -855,7 +853,7 @@ mod tests { let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); + let result = L2Denorm::try_new_array(lhs, rhs, &mut ctx); assert!(result.is_err()); } @@ -865,7 +863,7 @@ mod tests { let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); + let result = L2Denorm::try_new_array(lhs, rhs, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -876,7 +874,7 @@ mod tests { let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); + let result = L2Denorm::try_new_array(lhs, rhs, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -887,7 +885,7 @@ mod tests { let rhs = PrimitiveArray::from_iter([1.0f32, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); + let result = L2Denorm::try_new_array(lhs, rhs, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -916,7 +914,7 @@ mod tests { let norms = PrimitiveArray::from_iter([5.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); + let result = L2Denorm::try_new_array(normalized, norms, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -927,7 +925,7 @@ mod tests { let norms = PrimitiveArray::from_iter([0.0f64, 0.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); + let result = L2Denorm::try_new_array(normalized, norms, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -938,7 +936,7 @@ mod tests { let norms = PrimitiveArray::from_iter([1.0f64, -1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); + let result = L2Denorm::try_new_array(normalized, norms, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -948,7 +946,7 @@ mod tests { let normalized = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?; let norms = PrimitiveArray::from_iter([5.0f64, 1.0]).into_array(); - let result = unsafe { L2Denorm::new_array_unchecked(normalized, norms, 2) }; + let result = unsafe { L2Denorm::new_array_unchecked(normalized, norms) }; assert!(result.is_ok()); Ok(()) } @@ -1062,7 +1060,7 @@ mod tests { let normalized = vector_array(3, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?; let norms = constant_f64_norms(1.0, 2); - let actual = eval_l2_denorm(normalized.clone(), norms, 2)?; + let actual = eval_l2_denorm(normalized.clone(), norms)?; assert_tensor_arrays_eq(actual, normalized)?; Ok(()) } @@ -1074,7 +1072,7 @@ mod tests { let normalized = vector_array(3, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?; let norms = constant_f64_norms(1.0 + 1e-12, 2); - let actual = eval_l2_denorm(normalized.clone(), norms, 2)?; + let actual = eval_l2_denorm(normalized.clone(), norms)?; assert_tensor_arrays_eq(actual, normalized)?; Ok(()) } @@ -1086,7 +1084,7 @@ mod tests { let normalized = vector_array(3, &[0.6, 0.8, 0.0, 1.0, 0.0, 0.0])?; let norms = constant_f64_norms(5.0, 2); - let actual = eval_l2_denorm(normalized, norms, 2)?; + let actual = eval_l2_denorm(normalized, norms)?; let expected = vector_array(3, &[3.0, 4.0, 0.0, 5.0, 0.0, 0.0])?; assert_tensor_arrays_eq(actual, expected)?; Ok(()) @@ -1099,7 +1097,7 @@ mod tests { let normalized = tensor_array(&[2, 2], &[0.5, 0.5, 0.5, 0.5, 1.0, 0.0, 0.0, 0.0])?; let norms = constant_f64_norms(4.0, 2); - let actual = eval_l2_denorm(normalized, norms, 2)?; + let actual = eval_l2_denorm(normalized, norms)?; let expected = tensor_array(&[2, 2], &[2.0, 2.0, 2.0, 2.0, 4.0, 0.0, 0.0, 0.0])?; assert_tensor_arrays_eq(actual, expected)?; Ok(()) diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 59d49fc8e1a..11918e08ed5 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -75,8 +75,8 @@ impl L2Norm { /// /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype /// mismatches). - pub fn try_new_array(child: ArrayRef, len: usize) -> VortexResult { - ScalarFnArray::try_new(L2Norm::new().erased(), vec![child], len) + pub fn try_new_array(child: ArrayRef) -> VortexResult { + ScalarFnArray::try_new(L2Norm::new().erased(), vec![child]) } } @@ -282,9 +282,9 @@ mod tests { use crate::utils::test_helpers::vector_array; /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. - fn eval_l2_norm(input: ArrayRef, len: usize) -> VortexResult> { + fn eval_l2_norm(input: ArrayRef) -> VortexResult> { let scalar_fn = L2Norm::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![input], len)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![input])?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; Ok(prim.as_slice::().to_vec()) @@ -301,7 +301,7 @@ mod tests { #[case] expected: &[f64], ) -> VortexResult<()> { let arr = tensor_array(shape, elements)?; - assert_close(&eval_l2_norm(arr, 1)?, expected); + assert_close(&eval_l2_norm(arr)?, expected); Ok(()) } @@ -315,7 +315,7 @@ mod tests { 1.0, 1.0, 1.0, // norm = sqrt(3) ], )?; - assert_close(&eval_l2_norm(arr, 3)?, &[5.0, 0.0, 3.0_f64.sqrt()]); + assert_close(&eval_l2_norm(arr)?, &[5.0, 0.0, 3.0_f64.sqrt()]); Ok(()) } @@ -328,7 +328,7 @@ mod tests { 3.0, 4.0, 0.0, // norm = 5.0 ], )?; - assert_close(&eval_l2_norm(arr, 2)?, &[1.0, 5.0]); + assert_close(&eval_l2_norm(arr)?, &[1.0, 5.0]); Ok(()) } @@ -339,7 +339,7 @@ mod tests { let arr = MaskedArray::try_new(arr, Validity::from_iter([true, false]))?.into_array(); let scalar_fn = L2Norm::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![arr], 2)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![arr])?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; @@ -359,7 +359,7 @@ mod tests { let input = literal_vector_array(&[3.0f64, 4.0], 4); let scalar_fn = L2Norm::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![input], 4)?.into_array(); + let result = ScalarFnArray::try_new(scalar_fn, vec![input])?.into_array(); let mut ctx = SESSION.create_execution_ctx(); let output = result.execute_until::(&mut ctx)?; @@ -390,7 +390,7 @@ mod tests { let input = ConstantArray::new(null_scalar, 3).into_array(); let scalar_fn = L2Norm::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![input], 3)?.into_array(); + let result = ScalarFnArray::try_new(scalar_fn, vec![input])?.into_array(); let mut ctx = SESSION.create_execution_ctx(); let output = result.execute_until::(&mut ctx)?; @@ -407,10 +407,10 @@ mod tests { } #[rstest] - #[case::fixed_shape_tensor(l2_norm_tensor_child(), 2)] - #[case::vector(l2_norm_vector_child(), 2)] - fn serde_round_trip(#[case] child: ArrayRef, #[case] len: usize) -> VortexResult<()> { - let original = L2Norm::try_new_array(child.clone(), len)?.into_array(); + #[case::fixed_shape_tensor(l2_norm_tensor_child())] + #[case::vector(l2_norm_vector_child())] + fn serde_round_trip(#[case] child: ArrayRef) -> VortexResult<()> { + let original = L2Norm::try_new_array(child.clone())?.into_array(); let plugin = ScalarFnArrayPlugin::new(L2Norm); let metadata = plugin diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs b/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs index 26d38e87a1e..97847a6771c 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs @@ -106,14 +106,10 @@ impl SorfTransform { /// `options.element_ptype`. /// /// [`Vector`]: crate::vector::Vector - pub fn try_new_array( - options: &SorfOptions, - child: ArrayRef, - len: usize, - ) -> VortexResult { + pub fn try_new_array(options: &SorfOptions, child: ArrayRef) -> VortexResult { validate_sorf_options(options)?; - ScalarFnArray::try_new(SorfTransform::new(options).erased(), vec![child], len) + ScalarFnArray::try_new(SorfTransform::new(options).erased(), vec![child]) } } diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs index 5efe1436ed6..7fdaee28501 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs @@ -122,12 +122,8 @@ fn default_options(dim: u32, seed: u64) -> SorfOptions { } /// Execute a `SorfTransform` array and return the decoded flat f32 elements. -fn execute_sorf( - options: &SorfOptions, - child: ExtensionArray, - num_rows: usize, -) -> VortexResult> { - let sorf = SorfTransform::try_new_array(options, child.into_array(), num_rows)?; +fn execute_sorf(options: &SorfOptions, child: ExtensionArray) -> VortexResult> { + let sorf = SorfTransform::try_new_array(options, child.into_array())?; let mut ctx = SESSION.create_execution_ctx(); let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; @@ -150,7 +146,7 @@ fn roundtrip_recovery() -> VortexResult<()> { let seed = 42u64; let (input_f32, padded_vector, _) = forward_rotate_and_quantize(dim, num_rows, seed, 3, 8)?; let options = default_options(dim as u32, seed); - let result = execute_sorf(&options, padded_vector, num_rows)?; + let result = execute_sorf(&options, padded_vector)?; assert_eq!(result.len(), num_rows * dim); @@ -182,7 +178,7 @@ fn empty_array_non_nullable() -> VortexResult<()> { // Build an empty Vector child. let child = empty_padded_vector(padded_dim, Validity::NonNullable)?; - let sorf = SorfTransform::try_new_array(&options, child.into_array(), 0)?; + let sorf = SorfTransform::try_new_array(&options, child.into_array())?; let mut ctx = SESSION.create_execution_ctx(); let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; @@ -204,7 +200,7 @@ fn empty_array_nullable() -> VortexResult<()> { // Build an empty but nullable Vector child. let child = empty_padded_vector(padded_dim, Validity::from(Nullability::Nullable))?; - let sorf = SorfTransform::try_new_array(&options, child.into_array(), 0)?; + let sorf = SorfTransform::try_new_array(&options, child.into_array())?; let mut ctx = SESSION.create_execution_ctx(); let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; @@ -241,7 +237,7 @@ fn nullable_validity_propagation() -> VortexResult<()> { let nullable_vector = wrap_as_vector(fsl_nullable, validity.clone())?; let options = default_options(dim as u32, seed); - let sorf = SorfTransform::try_new_array(&options, nullable_vector.into_array(), num_rows)?; + let sorf = SorfTransform::try_new_array(&options, nullable_vector.into_array())?; let mut ctx = SESSION.create_execution_ctx(); let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; @@ -270,7 +266,7 @@ fn dimension_truncation() -> VortexResult<()> { assert_eq!(padded_dim, 256, "200 should pad to 256"); let options = default_options(dim as u32, seed); - let result = execute_sorf(&options, padded_vector, num_rows)?; + let result = execute_sorf(&options, padded_vector)?; // Output should have original dimension, not padded. assert_eq!(result.len(), num_rows * dim); @@ -324,7 +320,7 @@ fn rejects_zero_rounds_at_construction() { let child = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1) .expect("test child should be valid"); - let err = SorfTransform::try_new_array(&options, child.into_array(), 1) + let err = SorfTransform::try_new_array(&options, child.into_array()) .expect_err("zero rounds should be rejected at construction time"); assert!(err.to_string().contains("num_rounds")); } @@ -341,7 +337,7 @@ fn rejects_non_float_output_ptype_at_construction() { let child = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1) .expect("test child should be valid"); - let err = SorfTransform::try_new_array(&options, child.into_array(), 1) + let err = SorfTransform::try_new_array(&options, child.into_array()) .expect_err("non-float output ptypes should be rejected at construction time"); assert!(err.to_string().contains("element_ptype")); } @@ -354,7 +350,7 @@ fn rejects_non_vector_extension_child_at_construction() { let child = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1) .expect("test child should be valid"); - let err = SorfTransform::try_new_array(&options, child.into_array(), 1) + let err = SorfTransform::try_new_array(&options, child.into_array()) .expect_err("non-Vector-extension children should be rejected at construction time"); assert!(err.to_string().contains("Vector extension")); } @@ -368,7 +364,7 @@ fn rejects_wrong_padded_dimension_at_construction() { .expect("test child should be valid"); let child = wrap_as_vector(fsl, Validity::NonNullable).expect("wrap should succeed"); - let err = SorfTransform::try_new_array(&options, child.into_array(), 1) + let err = SorfTransform::try_new_array(&options, child.into_array()) .expect_err("mismatched padded dimension should be rejected at construction time"); assert!(err.to_string().contains("dimension")); } @@ -383,7 +379,7 @@ fn rejects_non_f32_child_storage_at_construction() { .expect("test child should be valid"); let child = wrap_as_vector(fsl, Validity::NonNullable).expect("wrap should succeed"); - let err = SorfTransform::try_new_array(&options, child.into_array(), 1) + let err = SorfTransform::try_new_array(&options, child.into_array()) .expect_err("non-f32 Vector storage should be rejected at construction time"); assert!(err.to_string().contains("f32")); } @@ -401,7 +397,7 @@ fn f16_output_type() -> VortexResult<()> { dimensions: dim as u32, element_ptype: PType::F16, }; - let sorf = SorfTransform::try_new_array(&options, padded_vector.into_array(), num_rows)?; + let sorf = SorfTransform::try_new_array(&options, padded_vector.into_array())?; let mut ctx = SESSION.create_execution_ctx(); let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; @@ -426,7 +422,7 @@ fn f64_output_type() -> VortexResult<()> { dimensions: dim as u32, element_ptype: PType::F64, }; - let sorf = SorfTransform::try_new_array(&options, padded_vector.into_array(), num_rows)?; + let sorf = SorfTransform::try_new_array(&options, padded_vector.into_array())?; let mut ctx = SESSION.create_execution_ctx(); let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; @@ -469,7 +465,7 @@ fn serde_round_trip(#[case] dimensions: u32, #[case] validity: Validity) -> Vort element_ptype: PType::F32, }; let child = trivial_padded_vector(padded_dim, num_rows, validity); - let original = SorfTransform::try_new_array(&options, child.clone(), num_rows)?.into_array(); + let original = SorfTransform::try_new_array(&options, child.clone())?.into_array(); let plugin = ScalarFnArrayPlugin::new(SorfTransform); let metadata = plugin diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 88845f4b767..f068881adf3 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -408,11 +408,10 @@ pub mod test_helpers { norms: &[T], ctx: &mut ExecutionCtx, ) -> VortexResult { - let len = norms.len(); let normalized = tensor_array(shape, normalized_elements)?; let norms = PrimitiveArray::new(Buffer::copy_from(norms), Validity::NonNullable).into_array(); - Ok(L2Denorm::try_new_array(normalized, norms, len, ctx)?.into_array()) + Ok(L2Denorm::try_new_array(normalized, norms, ctx)?.into_array()) } /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected` diff --git a/vortex-tensor/src/vector_search.rs b/vortex-tensor/src/vector_search.rs index 2a036accdfe..d55ab6e9eb0 100644 --- a/vortex-tensor/src/vector_search.rs +++ b/vortex-tensor/src/vector_search.rs @@ -83,7 +83,7 @@ pub fn build_similarity_search_tree>( let num_rows = data.len(); let query_vec = Vector::constant_array(query, num_rows)?; - let cosine = CosineSimilarity::try_new_array(data, query_vec, num_rows)?.into_array(); + let cosine = CosineSimilarity::try_new_array(data, query_vec)?.into_array(); let threshold_scalar = Scalar::primitive(threshold, Nullability::NonNullable); let threshold_array = ConstantArray::new(threshold_scalar, num_rows).into_array(); diff --git a/vortex-turboquant/src/scalar_fns/decode.rs b/vortex-turboquant/src/scalar_fns/decode.rs index 9d4d465dbe8..bb47f5f0eb4 100644 --- a/vortex-turboquant/src/scalar_fns/decode.rs +++ b/vortex-turboquant/src/scalar_fns/decode.rs @@ -58,8 +58,7 @@ impl TQDecode { /// Constructs a [`ScalarFnArray`] that lazily decodes a `TurboQuant` child into a `Vector`. pub fn try_new_array(child: ArrayRef) -> VortexResult { - let len = child.len(); - ScalarFnArray::try_new(TQDecode::new().erased(), vec![child], len) + ScalarFnArray::try_new(TQDecode::new().erased(), vec![child]) } } diff --git a/vortex-turboquant/src/scalar_fns/encode.rs b/vortex-turboquant/src/scalar_fns/encode.rs index 6dd16e4bb66..097f00be0ba 100644 --- a/vortex-turboquant/src/scalar_fns/encode.rs +++ b/vortex-turboquant/src/scalar_fns/encode.rs @@ -66,8 +66,7 @@ impl TQEncode { child: ArrayRef, config: &TurboQuantConfig, ) -> VortexResult { - let len = child.len(); - ScalarFnArray::try_new(TQEncode::new(config).erased(), vec![child], len) + ScalarFnArray::try_new(TQEncode::new(config).erased(), vec![child]) } } diff --git a/vortex-turboquant/src/vector/normalize.rs b/vortex-turboquant/src/vector/normalize.rs index c0a5c9f6f06..557e7f12541 100644 --- a/vortex-turboquant/src/vector/normalize.rs +++ b/vortex-turboquant/src/vector/normalize.rs @@ -52,7 +52,7 @@ pub(crate) fn tq_normalize_as_l2_denorm( let vector_validity = input.validity()?; // Use `L2Norm` to calculate the normals for each vector. - let norms: ArrayRef = L2Norm::try_new_array(input.clone(), row_count)? + let norms: ArrayRef = L2Norm::try_new_array(input.clone())? .into_array() .execute(ctx)?; let primitive_norms: PrimitiveArray = norms.clone().execute(ctx)?; @@ -84,7 +84,7 @@ pub(crate) fn tq_normalize_as_l2_denorm( // match the vector element type and row count. Valid nonzero rows are divided by their stored // norm and are unit-norm. Valid zero-norm rows and invalid rows use physical zero placeholders; // invalid rows remain guarded by row-level invalid validity. - unsafe { L2Denorm::new_array_unchecked(normalized, norms, row_count) } + unsafe { L2Denorm::new_array_unchecked(normalized, norms) } } fn normalize_vectors(