diff --git a/encodings/datetime-parts/src/compute/rules.rs b/encodings/datetime-parts/src/compute/rules.rs index 46bc2221e5d..bba19ceea47 100644 --- a/encodings/datetime-parts/src/compute/rules.rs +++ b/encodings/datetime-parts/src/compute/rules.rs @@ -133,10 +133,9 @@ impl ArrayParentReduceRule for DTPComparisonPushDownRule { } } - let result = - ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, parent.len())? - .into_array() - .optimize()?; + let result = ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)? + .into_array() + .optimize()?; Ok(Some(result)) } diff --git a/encodings/runend/src/rules.rs b/encodings/runend/src/rules.rs index cd80e42db46..ac7d6038146 100644 --- a/encodings/runend/src/rules.rs +++ b/encodings/runend/src/rules.rs @@ -77,8 +77,7 @@ impl ArrayParentReduceRule for RunEndScalarFnRule { } let new_values = - ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, values_len)? - .into_array(); + ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)?.into_array(); Ok(Some( unsafe { diff --git a/vortex-array/src/arrays/chunked/compute/rules.rs b/vortex-array/src/arrays/chunked/compute/rules.rs index d8d324a8e86..712578255bb 100644 --- a/vortex-array/src/arrays/chunked/compute/rules.rs +++ b/vortex-array/src/arrays/chunked/compute/rules.rs @@ -48,13 +48,9 @@ impl ArrayParentReduceRule for ChunkedUnaryScalarFnPushDownRule { let new_chunks: Vec<_> = array .iter_chunks() .map(|chunk| { - ScalarFnArray::try_new( - parent.scalar_fn().clone(), - vec![chunk.clone()], - chunk.len(), - )? - .into_array() - .optimize() + ScalarFnArray::try_new(parent.scalar_fn().clone(), vec![chunk.clone()])? + .into_array() + .optimize() }) .try_collect()?; @@ -104,7 +100,7 @@ impl ArrayParentReduceRule for ChunkedConstantScalarFnPushDownRule { }) .collect(); - ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, chunk.len())? + ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)? .into_array() .optimize() }) diff --git a/vortex-array/src/arrays/dict/compute/rules.rs b/vortex-array/src/arrays/dict/compute/rules.rs index 5150a22271a..024240e4030 100644 --- a/vortex-array/src/arrays/dict/compute/rules.rs +++ b/vortex-array/src/arrays/dict/compute/rules.rs @@ -126,10 +126,9 @@ impl ArrayParentReduceRule for DictionaryScalarFnValuesPushDownRule { } } - let new_values = - ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, values_len)? - .into_array() - .optimize()?; + let new_values = ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)? + .into_array() + .optimize()?; // We can only push down null-sensitive functions when we have all-valid codes. // In these cases, we cannot have the codes influence the nullability of the output DType. @@ -192,13 +191,9 @@ impl ArrayParentReduceRule for DictionaryScalarFnCodesPullUpRule { } } - let new_values = ScalarFnArray::try_new( - parent.scalar_fn().clone(), - new_children, - array.values().len(), - )? - .into_array() - .optimize()?; + let new_values = ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children)? + .into_array() + .optimize()?; let new_dict = unsafe { DictArray::new_unchecked(array.codes().clone(), new_values) }.into_array(); diff --git a/vortex-array/src/arrays/fixed_size_list/tests/nested.rs b/vortex-array/src/arrays/fixed_size_list/tests/nested.rs index efd6f630580..8fca62c0097 100644 --- a/vortex-array/src/arrays/fixed_size_list/tests/nested.rs +++ b/vortex-array/src/arrays/fixed_size_list/tests/nested.rs @@ -270,7 +270,6 @@ fn test_fsl_of_fsl_with_nulls() { #[test] fn test_deeply_nested_fsl() { - let _len = 2; let list_size = 2; // Create a 3-level nested FSL: FSL[FSL[FSL[i32]]]. diff --git a/vortex-array/src/arrays/scalar_fn/array.rs b/vortex-array/src/arrays/scalar_fn/array.rs index 2c77be3c329..a999a25cd51 100644 --- a/vortex-array/src/arrays/scalar_fn/array.rs +++ b/vortex-array/src/arrays/scalar_fn/array.rs @@ -6,6 +6,7 @@ use std::fmt::Formatter; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use crate::ArrayRef; @@ -30,19 +31,6 @@ impl Display for ScalarFnData { } impl ScalarFnData { - /// Create a new ScalarFnArray from a scalar function and its children. - pub fn build( - scalar_fn: ScalarFnRef, - children: Vec, - len: usize, - ) -> VortexResult { - vortex_ensure!( - children.iter().all(|c| c.len() == len), - "ScalarFnArray must have children equal to the array length" - ); - Ok(Self { scalar_fn }) - } - /// Get the scalar function bound to this array. #[inline(always)] pub fn scalar_fn(&self) -> &ScalarFnRef { @@ -85,14 +73,26 @@ impl> ScalarFnArrayExt for T {} impl Array { /// Create a new ScalarFnArray from a scalar function and its children. - pub fn try_new( + pub fn try_new(scalar_fn: ScalarFnRef, children: Vec) -> VortexResult { + let len = Self::infer_len(&children)?; + Self::try_new_with_len(scalar_fn, children, len) + } + + /// Create a new ScalarFnArray from a scalar function, children, and an explicit length. + /// + /// This is needed for zero-child scalar functions and deserialization paths where there is no + /// child array to infer the length from. + pub fn try_new_with_len( scalar_fn: ScalarFnRef, children: Vec, len: usize, ) -> VortexResult { + Self::validate_children_len(&children, len)?; let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect(); let dtype = scalar_fn.return_dtype(&arg_dtypes)?; - let data = ScalarFnData::build(scalar_fn.clone(), children.clone(), len)?; + let data = ScalarFnData { + scalar_fn: scalar_fn.clone(), + }; let vtable = ScalarFn { id: scalar_fn.id() }; Ok(unsafe { Array::from_parts_unchecked( @@ -101,4 +101,19 @@ impl Array { ) }) } + + fn infer_len(children: &[ArrayRef]) -> VortexResult { + let Some(child) = children.first() else { + vortex_bail!("ScalarFnArray length cannot be inferred without children"); + }; + Ok(child.len()) + } + + fn validate_children_len(children: &[ArrayRef], len: usize) -> VortexResult<()> { + vortex_ensure!( + children.iter().all(|c| c.len() == len), + "ScalarFnArray must have children equal to the array length" + ); + Ok(()) + } } diff --git a/vortex-array/src/arrays/scalar_fn/plugin.rs b/vortex-array/src/arrays/scalar_fn/plugin.rs index c685503a370..044f79362cd 100644 --- a/vortex-array/src/arrays/scalar_fn/plugin.rs +++ b/vortex-array/src/arrays/scalar_fn/plugin.rs @@ -82,7 +82,7 @@ impl ArrayPlugin for ScalarFnArrayPlugi let parts = ::deserialize( &self.0, dtype, len, metadata, children, session, )?; - Ok(ScalarFnArray::try_new( + Ok(ScalarFnArray::try_new_with_len( TypedScalarFnInstance::new(self.0.clone(), parts.options).erased(), parts.children, len, diff --git a/vortex-array/src/arrays/scalar_fn/rules.rs b/vortex-array/src/arrays/scalar_fn/rules.rs index 1e9563cf9de..abc8e18e762 100644 --- a/vortex-array/src/arrays/scalar_fn/rules.rs +++ b/vortex-array/src/arrays/scalar_fn/rules.rs @@ -84,7 +84,8 @@ impl ArrayParentReduceRule for ScalarFnSliceReduceRule { .collect::>()?; Ok(Some( - ScalarFnArray::try_new(array.scalar_fn().clone(), children, range.len())?.into_array(), + ScalarFnArray::try_new_with_len(array.scalar_fn().clone(), children, range.len())? + .into_array(), )) } } @@ -142,7 +143,7 @@ impl ReduceCtx for ArrayReduceCtx { children: &[ReduceNodeRef], ) -> VortexResult { Ok(Arc::new( - ScalarFnArray::try_new( + ScalarFnArray::try_new_with_len( scalar_fn, children .iter() @@ -191,8 +192,7 @@ impl ArrayParentReduceRule for ScalarFnUnaryFilterPushDownRule { .try_collect()?; let new_array = - ScalarFnArray::try_new(child.scalar_fn().clone(), new_children, parent.len())? - .into_array(); + ScalarFnArray::try_new(child.scalar_fn().clone(), new_children)?.into_array(); return Ok(Some(new_array)); } diff --git a/vortex-array/src/arrays/scalar_fn/vtable/operations.rs b/vortex-array/src/arrays/scalar_fn/vtable/operations.rs index 1f710af3cdf..67de19d936f 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/operations.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/operations.rs @@ -65,9 +65,12 @@ mod tests { use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; use crate::arrays::ScalarFnArray; + use crate::arrays::scalar_fn::ScalarFnArrayExt; use crate::assert_arrays_eq; + use crate::scalar::Scalar; use crate::scalar_fn::TypedScalarFnInstance; use crate::scalar_fn::fns::binary::Binary; + use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::operators::Operator; use crate::validity::Validity; @@ -77,7 +80,9 @@ mod tests { let rhs = buffer![10i32, 20, 30].into_array(); let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased(); - let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?; + let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; + + assert_eq!(scalar_fn_array.len(), 3); let result = scalar_fn_array .into_array() @@ -89,13 +94,47 @@ mod tests { Ok(()) } + #[test] + fn test_scalar_fn_inferred_len_rejects_mismatched_children() { + let lhs = buffer![1i32, 2, 3].into_array(); + let rhs = buffer![10i32, 20].into_array(); + + let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased(); + let err = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs]) + .expect_err("ScalarFnArray::try_new must reject mismatched child lengths"); + + assert!( + err.to_string() + .contains("ScalarFnArray must have children equal to the array length") + ); + } + + #[test] + fn test_scalar_fn_without_children_requires_explicit_len() -> VortexResult<()> { + let scalar_fn = TypedScalarFnInstance::new(Literal, Scalar::from(1i32)).erased(); + + let Err(err) = ScalarFnArray::try_new(scalar_fn.clone(), vec![]) else { + panic!("ScalarFnArray::try_new should reject zero children"); + }; + assert!( + err.to_string() + .contains("ScalarFnArray length cannot be inferred without children") + ); + + let scalar_fn_array = ScalarFnArray::try_new_with_len(scalar_fn, vec![], 3)?; + assert_eq!(scalar_fn_array.len(), 3); + assert_eq!(scalar_fn_array.child_count(), 0); + + Ok(()) + } + #[test] fn test_scalar_fn_mul() -> VortexResult<()> { let lhs = buffer![2i32, 3, 4].into_array(); let rhs = buffer![5i32, 6, 7].into_array(); let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Mul).erased(); - let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?; + let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let result = scalar_fn_array .into_array() @@ -117,7 +156,7 @@ mod tests { .into_array(); let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased(); - let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?; + let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let result = scalar_fn_array .into_array() @@ -139,7 +178,7 @@ mod tests { let rhs = buffer![2i32, 5, 1].into_array(); let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Eq).erased(); - let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?; + let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let result = scalar_fn_array .into_array() diff --git a/vortex-array/src/expression.rs b/vortex-array/src/expression.rs index 780bfcf9db1..d6b8739abf6 100644 --- a/vortex-array/src/expression.rs +++ b/vortex-array/src/expression.rs @@ -35,7 +35,8 @@ impl ArrayRef { // And wrap the scalar function up in an array. let array = - ScalarFnArray::try_new(expr.scalar_fn().clone(), children, self.len())?.into_array(); + ScalarFnArray::try_new_with_len(expr.scalar_fn().clone(), children, self.len())? + .into_array(); // Optimize the resulting array's root. array.optimize() diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 2b267671c8c..82328a9b2d9 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -71,8 +71,8 @@ impl CosineSimilarity { /// /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype /// mismatches). - pub fn try_new_array(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult { - ScalarFnArray::try_new(CosineSimilarity::new().erased(), vec![lhs, rhs], len) + pub fn try_new_array(lhs: ArrayRef, rhs: ArrayRef) -> VortexResult { + ScalarFnArray::try_new(CosineSimilarity::new().erased(), vec![lhs, rhs]) } } @@ -141,9 +141,9 @@ impl ScalarFnVTable for CosineSimilarity { let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; // Compute inner product and norms as columnar operations, and propagate the options. - let norm_lhs_arr = L2Norm::try_new_array(lhs_ref.clone(), len)?; - let norm_rhs_arr = L2Norm::try_new_array(rhs_ref.clone(), len)?; - let dot_arr = InnerProduct::try_new_array(lhs_ref, rhs_ref, len)?; + let norm_lhs_arr = L2Norm::try_new_array(lhs_ref.clone())?; + let norm_rhs_arr = L2Norm::try_new_array(rhs_ref.clone())?; + let dot_arr = InnerProduct::try_new_array(lhs_ref, rhs_ref)?; // Execute to get the inner product and norms of the arrays. We only fully decompress // because we need to perform special logic (guard against 0) during division. @@ -239,7 +239,7 @@ impl CosineSimilarity { // `L2Denorm` makes the normalized children authoritative, so their dot product is the // cosine similarity even for lossy storage wrappers, except that a zero stored norm still // represents a zero vector. - let dot: PrimitiveArray = InnerProduct::try_new_array(normalized_l, normalized_r, len)? + let dot: PrimitiveArray = InnerProduct::try_new_array(normalized_l, normalized_r)? .into_array() .execute(ctx)?; let norms_l: PrimitiveArray = norms_l.execute(ctx)?; @@ -279,12 +279,12 @@ impl CosineSimilarity { let (normalized, denorm_norms) = extract_l2_denorm_children(denorm_ref); - let dot_arr = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)?; + let dot_arr = InnerProduct::try_new_array(normalized, plain_ref.clone())?; let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?; let denorm_norms: PrimitiveArray = denorm_norms.execute(ctx)?; - let norm_arr = L2Norm::try_new_array(plain_ref.clone(), len)?; + let norm_arr = L2Norm::try_new_array(plain_ref.clone())?; let plain_norm: PrimitiveArray = norm_arr.into_array().execute(ctx)?; // TODO(connor): Ideally we would have a `SafeDiv` binary numeric operation. @@ -335,9 +335,9 @@ mod tests { use crate::utils::test_helpers::vector_array; /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. - fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { + fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef) -> VortexResult> { let scalar_fn = CosineSimilarity::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; Ok(prim.as_slice::().to_vec()) @@ -361,7 +361,7 @@ mod tests { )?; // Row 0: identical -> 1.0, row 1: orthogonal -> 0.0. - assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0, 0.0]); Ok(()) } @@ -381,7 +381,7 @@ mod tests { ) -> VortexResult<()> { let lhs = tensor_array(shape, lhs_elems)?; let rhs = tensor_array(shape, rhs_elems)?; - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, expected); + assert_close(&eval_cosine_similarity(lhs, rhs)?, expected); Ok(()) } @@ -400,7 +400,7 @@ mod tests { fn self_similarity(#[case] shape: &[usize], #[case] elements: &[f64]) -> VortexResult<()> { let lhs = tensor_array(shape, elements)?; let rhs = tensor_array(shape, elements)?; - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0]); Ok(()) } @@ -411,7 +411,7 @@ mod tests { let rhs = tensor_array(&[], &[5.0, -3.0])?; // Same sign -> 1.0, opposite sign -> -1.0. - assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, -1.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0, -1.0]); Ok(()) } @@ -431,7 +431,7 @@ mod tests { let rhs = lhs.clone(); assert_close( - &eval_cosine_similarity(lhs, rhs, 5)?, + &eval_cosine_similarity(lhs, rhs)?, &[1.0, 1.0, 1.0, 1.0, 1.0], ); Ok(()) @@ -451,10 +451,7 @@ mod tests { )?; let query = constant_tensor_array(&[3], &[1.0, 0.0, 0.0], 4)?; - assert_close( - &eval_cosine_similarity(data, query, 4)?, - &[1.0, 0.0, 0.0, 1.0], - ); + assert_close(&eval_cosine_similarity(data, query)?, &[1.0, 0.0, 0.0, 1.0]); Ok(()) } @@ -476,7 +473,7 @@ mod tests { )?; // Row 0: identical -> 1.0, row 1: orthogonal -> 0.0. - assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0, 0.0]); Ok(()) } @@ -493,10 +490,7 @@ mod tests { )?; let query = Vector::constant_array(&[1.0, 0.0, 0.0], 4)?; - assert_close( - &eval_cosine_similarity(data, query, 4)?, - &[1.0, 0.0, 0.0, 1.0], - ); + assert_close(&eval_cosine_similarity(data, query)?, &[1.0, 0.0, 0.0, 1.0]); Ok(()) } @@ -508,7 +502,7 @@ mod tests { let rhs = MaskedArray::try_new(rhs, Validity::from_iter([true, false]))?.into_array(); let scalar_fn = CosineSimilarity::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; @@ -528,7 +522,7 @@ mod tests { let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Self-similarity should always be 1.0. - assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 1.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0, 1.0]); Ok(()) } @@ -540,7 +534,7 @@ mod tests { let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0], &mut ctx)?; let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0], &mut ctx)?; - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[0.0]); Ok(()) } @@ -552,7 +546,7 @@ mod tests { let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0, row 1: dot([0,0], [1,0]) = 0.0. - assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0, 0.0]); Ok(()) } @@ -565,7 +559,7 @@ mod tests { let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; let rhs = tensor_array(&[2], &[3.0, 4.0])?; - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0]); Ok(()) } @@ -577,7 +571,7 @@ mod tests { let lhs = tensor_array(&[2], &[1.0, 0.0])?; let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.6]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[0.6]); Ok(()) } @@ -589,10 +583,10 @@ mod tests { let normalized_r = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?; let norms_r = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); - let rhs = L2Denorm::try_new_array(normalized_r, norms_r, 2, &mut ctx)?.into_array(); + let rhs = L2Denorm::try_new_array(normalized_r, norms_r, &mut ctx)?.into_array(); let scalar_fn = CosineSimilarity::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; assert!(prim.is_valid(0, &mut ctx)?); @@ -611,16 +605,16 @@ mod tests { let norms_l = PrimitiveArray::from_iter([0.0f64]).into_array(); // SAFETY: This is a focused test that intentionally violates the unit-norm invariant by // pairing a nonzero normalized row with a stored norm of `0.0`, mimicking lossy storage. - let lhs = unsafe { L2Denorm::new_array_unchecked(normalized_l, norms_l, 1)? }.into_array(); + let lhs = unsafe { L2Denorm::new_array_unchecked(normalized_l, norms_l)? }.into_array(); let normalized_r = tensor_array(&[2], &[0.6, 0.8])?; let norms_r = PrimitiveArray::from_iter([0.0f64]).into_array(); // SAFETY: Same as above for the rhs operand. - let rhs = unsafe { L2Denorm::new_array_unchecked(normalized_r, norms_r, 1)? }.into_array(); + let rhs = unsafe { L2Denorm::new_array_unchecked(normalized_r, norms_r)? }.into_array(); // `dot(normalized_l, normalized_r) = 1.0`, but the authoritative stored norms are both // `0.0`, so cosine similarity must be `0.0`. - assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[0.0]); Ok(()) } @@ -634,19 +628,19 @@ mod tests { let norms = PrimitiveArray::from_iter([0.0f64]).into_array(); // SAFETY: This is a focused test that intentionally pairs a nonzero normalized row with a // stored norm of `0.0`, mimicking lossy storage where the stored norm is authoritative. - let denorm = unsafe { L2Denorm::new_array_unchecked(normalized, norms, 1)? }.into_array(); + let denorm = unsafe { L2Denorm::new_array_unchecked(normalized, norms)? }.into_array(); let plain = tensor_array(&[2], &[1.0, 0.0])?; // Denorm on the lhs: `One { denorm: lhs, plain: rhs }`. assert_close( - &eval_cosine_similarity(denorm.clone(), plain.clone(), 1)?, + &eval_cosine_similarity(denorm.clone(), plain.clone())?, &[0.0], ); // Denorm on the rhs: `One { denorm: rhs, plain: lhs }`. The same zero-norm guard must // fire regardless of operand order. - assert_close(&eval_cosine_similarity(plain, denorm, 1)?, &[0.0]); + assert_close(&eval_cosine_similarity(plain, denorm)?, &[0.0]); Ok(()) } @@ -665,7 +659,7 @@ mod tests { ], )?; assert_close( - &eval_cosine_similarity(lhs, rhs, 4)?, + &eval_cosine_similarity(lhs, rhs)?, &[1.0 / 3.0, 1.0, 2.0 / 3.0, 8.0 / 9.0], ); Ok(()) @@ -685,7 +679,7 @@ mod tests { )?; let rhs = constant_tensor_array(&[3], &[1.0, 2.0, 2.0], 4)?; assert_close( - &eval_cosine_similarity(lhs, rhs, 4)?, + &eval_cosine_similarity(lhs, rhs)?, &[1.0 / 3.0, 1.0, 2.0 / 3.0, 8.0 / 9.0], ); Ok(()) @@ -698,7 +692,7 @@ mod tests { let rhs = constant_tensor_array(&[3], &[1.0, 1.0, 0.0], 3)?; let expected = 1.0 / 2.0_f64.sqrt(); assert_close( - &eval_cosine_similarity(lhs, rhs, 3)?, + &eval_cosine_similarity(lhs, rhs)?, &[expected, expected, expected], ); Ok(()) @@ -717,7 +711,7 @@ mod tests { 7.0, 8.0, 9.0, // ], )?; - assert_close(&eval_cosine_similarity(lhs, rhs, 3)?, &[0.0, 0.0, 0.0]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[0.0, 0.0, 0.0]); Ok(()) } @@ -728,7 +722,7 @@ mod tests { // L2Denorm fast path's inner product yields 1. let lhs = constant_tensor_array(&[3], &[3.0, 4.0, 0.0], 5)?; let rhs = constant_tensor_array(&[3], &[3.0, 4.0, 0.0], 5)?; - assert_close(&eval_cosine_similarity(lhs, rhs, 5)?, &[1.0; 5]); + assert_close(&eval_cosine_similarity(lhs, rhs)?, &[1.0; 5]); Ok(()) } @@ -746,21 +740,17 @@ mod tests { ], )?; assert_close( - &eval_cosine_similarity(lhs, rhs, 4)?, + &eval_cosine_similarity(lhs, rhs)?, &[1.0 / 3.0, 1.0, 2.0 / 3.0, 8.0 / 9.0], ); Ok(()) } #[rstest] - #[case::vector(cosine_vector_lhs(), cosine_vector_rhs(), 2)] - #[case::fixed_shape_tensor(cosine_tensor_lhs(), cosine_tensor_rhs(), 2)] - fn serde_round_trip( - #[case] lhs: ArrayRef, - #[case] rhs: ArrayRef, - #[case] len: usize, - ) -> VortexResult<()> { - let original = CosineSimilarity::try_new_array(lhs.clone(), rhs.clone(), len)?.into_array(); + #[case::vector(cosine_vector_lhs(), cosine_vector_rhs())] + #[case::fixed_shape_tensor(cosine_tensor_lhs(), cosine_tensor_rhs())] + fn serde_round_trip(#[case] lhs: ArrayRef, #[case] rhs: ArrayRef) -> VortexResult<()> { + let original = CosineSimilarity::try_new_array(lhs.clone(), rhs.clone())?.into_array(); let plugin = ScalarFnArrayPlugin::new(CosineSimilarity); let metadata = plugin diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index 8f6b67ce11b..cfea956f314 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -68,8 +68,8 @@ impl InnerProduct { /// /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype /// mismatches). - pub fn try_new_array(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult { - ScalarFnArray::try_new(InnerProduct::new().erased(), vec![lhs, rhs], len) + pub fn try_new_array(lhs: ArrayRef, rhs: ArrayRef) -> VortexResult { + ScalarFnArray::try_new(InnerProduct::new().erased(), vec![lhs, rhs]) } } @@ -222,7 +222,7 @@ impl InnerProduct { let norms_l: PrimitiveArray = norms_l.execute(ctx)?; let norms_r: PrimitiveArray = norms_r.execute(ctx)?; - let dot: PrimitiveArray = InnerProduct::try_new_array(normalized_l, normalized_r, len)? + let dot: PrimitiveArray = InnerProduct::try_new_array(normalized_l, normalized_r)? .into_array() .execute(ctx)?; @@ -252,7 +252,7 @@ impl InnerProduct { let (normalized, norms) = extract_l2_denorm_children(denorm_ref); let denorm_norms: PrimitiveArray = norms.execute(ctx)?; - let dot: PrimitiveArray = InnerProduct::try_new_array(normalized, plain_ref.clone(), len)? + let dot: PrimitiveArray = InnerProduct::try_new_array(normalized, plain_ref.clone())? .into_array() .execute(ctx)?; @@ -301,9 +301,9 @@ mod tests { use crate::utils::test_helpers::vector_array; /// Evaluates inner product between two tensor arrays and returns the result as `Vec`. - fn eval_inner_product(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { + fn eval_inner_product(lhs: ArrayRef, rhs: ArrayRef) -> VortexResult> { let scalar_fn = InnerProduct::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; Ok(prim.as_slice::().to_vec()) @@ -327,7 +327,7 @@ mod tests { ) -> VortexResult<()> { let lhs = tensor_array(shape, lhs_elems)?; let rhs = tensor_array(shape, rhs_elems)?; - assert_close(&eval_inner_product(lhs, rhs, 1)?, expected); + assert_close(&eval_inner_product(lhs, rhs)?, expected); Ok(()) } @@ -349,7 +349,7 @@ mod tests { 2.0, 2.0, 2.0, // tensor 2: dot = 6 ], )?; - assert_close(&eval_inner_product(lhs, rhs, 3)?, &[0.0, 25.0, 6.0]); + assert_close(&eval_inner_product(lhs, rhs)?, &[0.0, 25.0, 6.0]); Ok(()) } @@ -369,7 +369,7 @@ mod tests { 0.0, 1.0, // vector 1: dot = 0 ], )?; - assert_close(&eval_inner_product(lhs, rhs, 2)?, &[25.0, 0.0]); + assert_close(&eval_inner_product(lhs, rhs)?, &[25.0, 0.0]); Ok(()) } @@ -381,7 +381,7 @@ mod tests { let lhs = MaskedArray::try_new(lhs, Validity::from_iter([true, false, true]))?.into_array(); let scalar_fn = InnerProduct::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; @@ -398,7 +398,7 @@ mod tests { fn rejects_non_extension_dtype() { let lhs = PrimitiveArray::from_iter([1.0_f64, 2.0]).into_array(); let rhs = PrimitiveArray::from_iter([3.0_f64, 4.0]).into_array(); - let result = InnerProduct::try_new_array(lhs, rhs, 2); + let result = InnerProduct::try_new_array(lhs, rhs); assert!(result.is_err()); } @@ -406,7 +406,7 @@ mod tests { fn rejects_mismatched_dtypes() -> VortexResult<()> { let lhs = tensor_array(&[2], &[1.0_f64, 2.0])?; let rhs = vector_array(2, &[3.0_f64, 4.0])?; - let result = InnerProduct::try_new_array(lhs, rhs, 1); + let result = InnerProduct::try_new_array(lhs, rhs); assert!(result.is_err()); Ok(()) } @@ -421,7 +421,7 @@ mod tests { let rhs = l2_denorm_array(&[2], &[1.0, 0.0], &[1.0], &mut ctx)?; // Expected: 5.0 * 1.0 * dot([0.6, 0.8], [1.0, 0.0]) = 5.0 * 0.6 = 3.0. - assert_close(&eval_inner_product(lhs, rhs, 1)?, &[3.0]); + assert_close(&eval_inner_product(lhs, rhs)?, &[3.0]); Ok(()) } @@ -433,7 +433,7 @@ mod tests { let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0], &mut ctx)?; - assert_close(&eval_inner_product(lhs, rhs, 2)?, &[25.0, 0.0]); + assert_close(&eval_inner_product(lhs, rhs)?, &[25.0, 0.0]); Ok(()) } @@ -446,7 +446,7 @@ mod tests { let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; let rhs = tensor_array(&[2], &[1.0, 2.0])?; - assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); + assert_close(&eval_inner_product(lhs, rhs)?, &[11.0]); Ok(()) } @@ -459,7 +459,7 @@ mod tests { let lhs = tensor_array(&[2], &[1.0, 2.0])?; let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; - assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); + assert_close(&eval_inner_product(lhs, rhs)?, &[11.0]); Ok(()) } @@ -470,11 +470,11 @@ mod tests { let norms_l = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let lhs = L2Denorm::try_new_array(normalized_l, norms_l, 2, &mut ctx)?.into_array(); + let lhs = L2Denorm::try_new_array(normalized_l, norms_l, &mut ctx)?.into_array(); let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; let scalar_fn = InnerProduct::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?; let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; // Row 0: 5.0 * 5.0 * dot([0.6, 0.8], [0.6, 0.8]) = 25.0, row 1: null. @@ -485,14 +485,10 @@ mod tests { } #[rstest] - #[case::vector(inner_product_vector_lhs(), inner_product_vector_rhs(), 2)] - #[case::fixed_shape_tensor(inner_product_tensor_lhs(), inner_product_tensor_rhs(), 2)] - fn serde_round_trip( - #[case] lhs: ArrayRef, - #[case] rhs: ArrayRef, - #[case] len: usize, - ) -> VortexResult<()> { - let original = InnerProduct::try_new_array(lhs.clone(), rhs.clone(), len)?.into_array(); + #[case::vector(inner_product_vector_lhs(), inner_product_vector_rhs())] + #[case::fixed_shape_tensor(inner_product_tensor_lhs(), inner_product_tensor_rhs())] + fn serde_round_trip(#[case] lhs: ArrayRef, #[case] rhs: ArrayRef) -> VortexResult<()> { + let original = InnerProduct::try_new_array(lhs.clone(), rhs.clone())?.into_array(); let plugin = ScalarFnArrayPlugin::new(InnerProduct); let metadata = plugin diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index d8ca5fc41ed..4f4159b4a21 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -111,13 +111,12 @@ impl L2Denorm { pub fn try_new_array( normalized: ArrayRef, norms: ArrayRef, - len: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { validate_l2_normalized_rows_against_norms(&normalized, Some(&norms), ctx)?; // SAFETY: We just validated that it is normalized. - unsafe { Self::new_array_unchecked(normalized, norms, len) } + unsafe { Self::new_array_unchecked(normalized, norms) } } /// Constructs an [`L2Denorm`] array without validating that the `normalized` child is actually @@ -140,9 +139,8 @@ impl L2Denorm { pub unsafe fn new_array_unchecked( normalized: ArrayRef, norms: ArrayRef, - len: usize, ) -> VortexResult { - ScalarFnArray::try_new(L2Denorm::new().erased(), vec![normalized, norms], len) + ScalarFnArray::try_new(L2Denorm::new().erased(), vec![normalized, norms]) } } @@ -432,7 +430,7 @@ pub fn normalize_as_l2_denorm( } // Calculate the norms of the vectors. - let norms_sfn = L2Norm::try_new_array(input.clone(), row_count)?; + let norms_sfn = L2Norm::try_new_array(input.clone())?; let norms_array: ArrayRef = norms_sfn.into_array().execute(ctx)?; let primitive_norms: PrimitiveArray = norms_array.clone().execute(ctx)?; let norms_validity = primitive_norms.validity()?; @@ -487,7 +485,7 @@ pub fn normalize_as_l2_denorm( // construction. // - Null rows are zeroed out above to avoid propagating arbitrary physical storage values into // downstream lossy encodings. - unsafe { L2Denorm::new_array_unchecked(normalized, norms_array, row_count) } + unsafe { L2Denorm::new_array_unchecked(normalized, norms_array) } } /// Attempts to build an [`L2Denorm`] whose two children are both [`ConstantArray`]s by eagerly @@ -570,7 +568,7 @@ pub(crate) fn try_build_constant_l2_denorm( // point tolerance) or all zeros when `||v|| == 0`. Stored norms are non-negative by // construction (`sqrt`). These are exactly the invariants required by // [`L2Denorm::new_array_unchecked`]. - let wrapped = unsafe { L2Denorm::new_array_unchecked(normalized_ext, norms_array, len)? }; + let wrapped = unsafe { L2Denorm::new_array_unchecked(normalized_ext, norms_array)? }; Ok(Some(wrapped)) } @@ -771,9 +769,9 @@ mod tests { use crate::utils::test_helpers::vector_array; /// Evaluates L2 denorm on a tensor/vector array and returns the executed array. - fn eval_l2_denorm(normalized: ArrayRef, norms: ArrayRef, len: usize) -> VortexResult { + fn eval_l2_denorm(normalized: ArrayRef, norms: ArrayRef) -> VortexResult { let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?; + let result = L2Denorm::try_new_array(normalized, norms, &mut ctx)?; result.into_array().execute(&mut ctx) } @@ -813,7 +811,7 @@ mod tests { fn l2_denorm_vectors() -> VortexResult<()> { let lhs = vector_array(3, &[0.6, 0.8, 0.0, 0.0, 0.0, 0.0])?; let rhs = PrimitiveArray::from_iter([5.0f64, 0.0]).into_array(); - let actual = eval_l2_denorm(lhs, rhs, 2)?; + let actual = eval_l2_denorm(lhs, rhs)?; let expected = vector_array(3, &[3.0, 4.0, 0.0, 0.0, 0.0, 0.0])?; assert_tensor_arrays_eq(actual, expected)?; @@ -824,7 +822,7 @@ mod tests { fn l2_denorm_fixed_shape_tensors() -> VortexResult<()> { let lhs = tensor_array(&[2, 2], &[0.5, 0.5, 0.5, 0.5, 1.0, 0.0, 0.0, 0.0])?; let rhs = PrimitiveArray::from_iter([4.0f64, 2.0]).into_array(); - let actual = eval_l2_denorm(lhs, rhs, 2)?; + let actual = eval_l2_denorm(lhs, rhs)?; let expected = tensor_array(&[2, 2], &[2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.0, 0.0])?; assert_tensor_arrays_eq(actual, expected)?; @@ -838,7 +836,7 @@ mod tests { let rhs = PrimitiveArray::from_option_iter([Some(5.0f64), Some(2.0), None]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let actual: ExtensionArray = eval_l2_denorm(lhs, rhs, 3)?.execute(&mut ctx)?; + let actual: ExtensionArray = eval_l2_denorm(lhs, rhs)?.execute(&mut ctx)?; let storage: FixedSizeListArray = actual.storage_array().clone().execute(&mut ctx)?; let elements: PrimitiveArray = storage.elements().clone().execute(&mut ctx)?; @@ -855,7 +853,7 @@ mod tests { let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); + let result = L2Denorm::try_new_array(lhs, rhs, &mut ctx); assert!(result.is_err()); } @@ -865,7 +863,7 @@ mod tests { let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); + let result = L2Denorm::try_new_array(lhs, rhs, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -876,7 +874,7 @@ mod tests { let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); + let result = L2Denorm::try_new_array(lhs, rhs, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -887,7 +885,7 @@ mod tests { let rhs = PrimitiveArray::from_iter([1.0f32, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(lhs, rhs, 2, &mut ctx); + let result = L2Denorm::try_new_array(lhs, rhs, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -916,7 +914,7 @@ mod tests { let norms = PrimitiveArray::from_iter([5.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); + let result = L2Denorm::try_new_array(normalized, norms, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -927,7 +925,7 @@ mod tests { let norms = PrimitiveArray::from_iter([0.0f64, 0.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); + let result = L2Denorm::try_new_array(normalized, norms, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -938,7 +936,7 @@ mod tests { let norms = PrimitiveArray::from_iter([1.0f64, -1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); - let result = L2Denorm::try_new_array(normalized, norms, 2, &mut ctx); + let result = L2Denorm::try_new_array(normalized, norms, &mut ctx); assert!(result.is_err()); Ok(()) } @@ -948,7 +946,7 @@ mod tests { let normalized = vector_array(2, &[3.0, 4.0, 1.0, 0.0])?; let norms = PrimitiveArray::from_iter([5.0f64, 1.0]).into_array(); - let result = unsafe { L2Denorm::new_array_unchecked(normalized, norms, 2) }; + let result = unsafe { L2Denorm::new_array_unchecked(normalized, norms) }; assert!(result.is_ok()); Ok(()) } @@ -1062,7 +1060,7 @@ mod tests { let normalized = vector_array(3, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?; let norms = constant_f64_norms(1.0, 2); - let actual = eval_l2_denorm(normalized.clone(), norms, 2)?; + let actual = eval_l2_denorm(normalized.clone(), norms)?; assert_tensor_arrays_eq(actual, normalized)?; Ok(()) } @@ -1074,7 +1072,7 @@ mod tests { let normalized = vector_array(3, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?; let norms = constant_f64_norms(1.0 + 1e-12, 2); - let actual = eval_l2_denorm(normalized.clone(), norms, 2)?; + let actual = eval_l2_denorm(normalized.clone(), norms)?; assert_tensor_arrays_eq(actual, normalized)?; Ok(()) } @@ -1086,7 +1084,7 @@ mod tests { let normalized = vector_array(3, &[0.6, 0.8, 0.0, 1.0, 0.0, 0.0])?; let norms = constant_f64_norms(5.0, 2); - let actual = eval_l2_denorm(normalized, norms, 2)?; + let actual = eval_l2_denorm(normalized, norms)?; let expected = vector_array(3, &[3.0, 4.0, 0.0, 5.0, 0.0, 0.0])?; assert_tensor_arrays_eq(actual, expected)?; Ok(()) @@ -1099,7 +1097,7 @@ mod tests { let normalized = tensor_array(&[2, 2], &[0.5, 0.5, 0.5, 0.5, 1.0, 0.0, 0.0, 0.0])?; let norms = constant_f64_norms(4.0, 2); - let actual = eval_l2_denorm(normalized, norms, 2)?; + let actual = eval_l2_denorm(normalized, norms)?; let expected = tensor_array(&[2, 2], &[2.0, 2.0, 2.0, 2.0, 4.0, 0.0, 0.0, 0.0])?; assert_tensor_arrays_eq(actual, expected)?; Ok(()) diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 7e8a85cc500..aa85ab4c78e 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -75,8 +75,8 @@ impl L2Norm { /// /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype /// mismatches). - pub fn try_new_array(child: ArrayRef, len: usize) -> VortexResult { - ScalarFnArray::try_new(L2Norm::new().erased(), vec![child], len) + pub fn try_new_array(child: ArrayRef) -> VortexResult { + ScalarFnArray::try_new(L2Norm::new().erased(), vec![child]) } } @@ -282,9 +282,9 @@ mod tests { use crate::utils::test_helpers::vector_array; /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. - fn eval_l2_norm(input: ArrayRef, len: usize) -> VortexResult> { + fn eval_l2_norm(input: ArrayRef) -> VortexResult> { let scalar_fn = L2Norm::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![input], len)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![input])?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; Ok(prim.as_slice::().to_vec()) @@ -301,7 +301,7 @@ mod tests { #[case] expected: &[f64], ) -> VortexResult<()> { let arr = tensor_array(shape, elements)?; - assert_close(&eval_l2_norm(arr, 1)?, expected); + assert_close(&eval_l2_norm(arr)?, expected); Ok(()) } @@ -315,7 +315,7 @@ mod tests { 1.0, 1.0, 1.0, // norm = sqrt(3) ], )?; - assert_close(&eval_l2_norm(arr, 3)?, &[5.0, 0.0, 3.0_f64.sqrt()]); + assert_close(&eval_l2_norm(arr)?, &[5.0, 0.0, 3.0_f64.sqrt()]); Ok(()) } @@ -328,7 +328,7 @@ mod tests { 3.0, 4.0, 0.0, // norm = 5.0 ], )?; - assert_close(&eval_l2_norm(arr, 2)?, &[1.0, 5.0]); + assert_close(&eval_l2_norm(arr)?, &[1.0, 5.0]); Ok(()) } @@ -339,7 +339,7 @@ mod tests { let arr = MaskedArray::try_new(arr, Validity::from_iter([true, false]))?.into_array(); let scalar_fn = L2Norm::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![arr], 2)?; + let result = ScalarFnArray::try_new(scalar_fn, vec![arr])?; let mut ctx = SESSION.create_execution_ctx(); let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; @@ -359,7 +359,7 @@ mod tests { let input = literal_vector_array(&[3.0f64, 4.0], 4); let scalar_fn = L2Norm::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![input], 4)?.into_array(); + let result = ScalarFnArray::try_new(scalar_fn, vec![input])?.into_array(); let mut ctx = SESSION.create_execution_ctx(); let output = result.execute_until::(&mut ctx)?; @@ -390,7 +390,7 @@ mod tests { let input = ConstantArray::new(null_scalar, 3).into_array(); let scalar_fn = L2Norm::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![input], 3)?.into_array(); + let result = ScalarFnArray::try_new(scalar_fn, vec![input])?.into_array(); let mut ctx = SESSION.create_execution_ctx(); let output = result.execute_until::(&mut ctx)?; @@ -407,10 +407,10 @@ mod tests { } #[rstest] - #[case::fixed_shape_tensor(l2_norm_tensor_child(), 2)] - #[case::vector(l2_norm_vector_child(), 2)] - fn serde_round_trip(#[case] child: ArrayRef, #[case] len: usize) -> VortexResult<()> { - let original = L2Norm::try_new_array(child.clone(), len)?.into_array(); + #[case::fixed_shape_tensor(l2_norm_tensor_child())] + #[case::vector(l2_norm_vector_child())] + fn serde_round_trip(#[case] child: ArrayRef) -> VortexResult<()> { + let original = L2Norm::try_new_array(child.clone())?.into_array(); let plugin = ScalarFnArrayPlugin::new(L2Norm); let metadata = plugin diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index f9558a07373..0f9bc1b8653 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -372,11 +372,10 @@ pub mod test_helpers { norms: &[T], ctx: &mut ExecutionCtx, ) -> VortexResult { - let len = norms.len(); let normalized = tensor_array(shape, normalized_elements)?; let norms = PrimitiveArray::new(Buffer::copy_from(norms), Validity::NonNullable).into_array(); - Ok(L2Denorm::try_new_array(normalized, norms, len, ctx)?.into_array()) + Ok(L2Denorm::try_new_array(normalized, norms, ctx)?.into_array()) } /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected` diff --git a/vortex-tensor/src/vector_search.rs b/vortex-tensor/src/vector_search.rs index 2c053ccf0a2..ad3b96d1bff 100644 --- a/vortex-tensor/src/vector_search.rs +++ b/vortex-tensor/src/vector_search.rs @@ -79,7 +79,7 @@ pub fn build_similarity_search_tree>( let num_rows = data.len(); let query_vec = Vector::constant_array(query, num_rows)?; - let cosine = CosineSimilarity::try_new_array(data, query_vec, num_rows)?.into_array(); + let cosine = CosineSimilarity::try_new_array(data, query_vec)?.into_array(); let threshold_scalar = Scalar::primitive(threshold, Nullability::NonNullable); let threshold_array = ConstantArray::new(threshold_scalar, num_rows).into_array();