@@ -296,15 +296,15 @@ template class SkeletonStateAccessor<float>;
296296template class SkeletonStateAccessor <double >;
297297
298298// ============================================================================
299- // VertexPositionsAccessor implementation
299+ // VectorArrayAccessor implementation
300300// ============================================================================
301301
302- template <typename T>
303- VertexPositionsAccessor<T >::VertexPositionsAccessor (
302+ template <typename T, int Dim >
303+ VectorArrayAccessor<T, Dim >::VectorArrayAccessor (
304304 const py::buffer_info& bufferInfo,
305305 const LeadingDimensions& leadingDims,
306- py::ssize_t nVertices )
307- : nVertices_(nVertices ), leadingNDim_(leadingDims.ndim()) {
306+ py::ssize_t nElements )
307+ : nElements_(nElements ), leadingNDim_(leadingDims.ndim()) {
308308 data_ = static_cast <T*>(bufferInfo.ptr );
309309
310310 // Extract strides (convert from bytes to elements)
@@ -313,10 +313,14 @@ VertexPositionsAccessor<T>::VertexPositionsAccessor(
313313 for (int i = 0 ; i < totalNDim; ++i) {
314314 strides_[i] = static_cast <py::ssize_t >(bufferInfo.strides [i] / sizeof (T));
315315 }
316+
317+ // Cache the row and column strides for the trailing dimensions
318+ rowStride_ = strides_[leadingNDim_]; // stride for element dimension
319+ colStride_ = strides_[leadingNDim_ + 1 ]; // stride for vector component dimension
316320}
317321
318- template <typename T>
319- py::ssize_t VertexPositionsAccessor<T >::computeOffset(
322+ template <typename T, int Dim >
323+ py::ssize_t VectorArrayAccessor<T, Dim >::computeOffset(
320324 const std::vector<py::ssize_t >& batchIndices) const {
321325 py::ssize_t offset = 0 ;
322326
@@ -329,52 +333,245 @@ py::ssize_t VertexPositionsAccessor<T>::computeOffset(
329333 return offset;
330334}
331335
332- template <typename T>
333- std::vector<Eigen::Vector3<T>> VertexPositionsAccessor<T >::get (
334- const std::vector<py::ssize_t >& batchIndices) const {
336+ template <typename T, int Dim >
337+ typename VectorArrayAccessor<T, Dim>::ElementView VectorArrayAccessor<T, Dim >::view (
338+ const std::vector<py::ssize_t >& batchIndices) {
335339 const auto offset = computeOffset (batchIndices);
340+ return ElementView (data_ + offset, nElements_, rowStride_, colStride_);
341+ }
template <typename T, int Dim>
typename VectorArrayAccessor<T, Dim>::ElementView VectorArrayAccessor<T, Dim>::view(
    const std::vector<py::ssize_t >& batchIndices) const {
  const auto offset = computeOffset (batchIndices);
  // NOTE(review): this const overload hands out a (potentially mutable)
  // ElementView over const data via const_cast. The original comment claims
  // this is "safe because we control access at the call site" — confirm that
  // ElementView never writes through this pointer when obtained from a const
  // accessor, since the type system no longer enforces it.
  return ElementView (const_cast <T*>(data_) + offset, nElements_, rowStride_, colStride_);
}
340350
341- const auto rowStride = strides_[leadingNDim_]; // stride for vertex dimension
342- const auto colStride = strides_[leadingNDim_ + 1 ]; // stride for xyz dimension
351+ template <typename T, int Dim>
352+ std::vector<typename VectorArrayAccessor<T, Dim>::VectorType> VectorArrayAccessor<T, Dim>::get(
353+ const std::vector<py::ssize_t >& batchIndices) const {
354+ const auto offset = computeOffset (batchIndices);
343355
344- for (py::ssize_t iVert = 0 ; iVert < nVertices_; ++iVert) {
345- const auto vertOffset = offset + iVert * rowStride;
346- positions[iVert].x () = data_[vertOffset + 0 * colStride];
347- positions[iVert].y () = data_[vertOffset + 1 * colStride];
348- positions[iVert].z () = data_[vertOffset + 2 * colStride];
356+ std::vector<VectorType> result (nElements_);
357+ for (py::ssize_t i = 0 ; i < nElements_; ++i) {
358+ const auto elemOffset = offset + i * rowStride_;
359+ for (int d = 0 ; d < Dim; ++d) {
360+ result[i][d] = data_[elemOffset + d * colStride_];
361+ }
349362 }
350363
351- return positions ;
364+ return result ;
352365}
353366
354- template <typename T>
355- void VertexPositionsAccessor<T >::set(
367+ template <typename T, int Dim >
368+ void VectorArrayAccessor<T, Dim >::set(
356369 const std::vector<py::ssize_t >& batchIndices,
357- const std::vector<Eigen::Vector3<T>>& positions ) {
370+ const std::vector<VectorType>& values ) {
358371 MT_THROW_IF (
359- static_cast <py::ssize_t >(positions .size ()) != nVertices_ ,
360- " set: expected {} vertices but got {}" ,
361- nVertices_ ,
362- positions .size ());
372+ static_cast <py::ssize_t >(values .size ()) != nElements_ ,
373+ " set: expected {} elements but got {}" ,
374+ nElements_ ,
375+ values .size ());
363376
364377 const auto offset = computeOffset (batchIndices);
365- const auto rowStride = strides_[leadingNDim_]; // stride for vertex dimension
366- const auto colStride = strides_[leadingNDim_ + 1 ]; // stride for xyz dimension
367-
368- for (py::ssize_t iVert = 0 ; iVert < nVertices_; ++iVert) {
369- const auto vertOffset = offset + iVert * rowStride;
370- data_[vertOffset + 0 * colStride] = positions[iVert].x ();
371- data_[vertOffset + 1 * colStride] = positions[iVert].y ();
372- data_[vertOffset + 2 * colStride] = positions[iVert].z ();
378+ for (py::ssize_t i = 0 ; i < nElements_; ++i) {
379+ const auto elemOffset = offset + i * rowStride_;
380+ for (int d = 0 ; d < Dim; ++d) {
381+ data_[elemOffset + d * colStride_] = values[i][d];
382+ }
373383 }
374384}
375385
376- // Explicit template instantiations
377- template class VertexPositionsAccessor <float >;
378- template class VertexPositionsAccessor <double >;
386+ // Explicit template instantiations for VectorArrayAccessor
387+ template class VectorArrayAccessor <float , 3 >;
388+ template class VectorArrayAccessor <double , 3 >;
389+ template class VectorArrayAccessor <int , 3 >;
390+
391+ // ============================================================================
392+ // IntVectorArrayAccessor implementation
393+ // ============================================================================
394+
// Wraps an integer numpy buffer of shape (..., nElements, Dim), recording the
// source dtype and byte strides so elements can be decoded lazily in get().
template <int Dim>
IntVectorArrayAccessor<Dim>::IntVectorArrayAccessor(
    const py::buffer& buffer,
    const LeadingDimensions& leadingDims,
    py::ssize_t nElements)
    : nElements_(nElements), leadingNDim_(leadingDims.ndim()) {
  auto bufferInfo = buffer.request();
  data_ = bufferInfo.ptr;

  // Detect the source dtype based on format code and itemsize.
  // Format codes 'l' (long) and 'L' (unsigned long) are platform-dependent:
  // - LP64 (Linux/macOS 64-bit): long is 64-bit
  // - LLP64 (Windows 64-bit): long is 32-bit
  // We use itemsize to disambiguate these cases.
  const auto& fmt = bufferInfo.format;
  const auto itemsize = bufferInfo.itemsize;

  // Check for signed integer types. Some of these descriptors may yield the
  // same format character on a given platform; the redundancy is harmless.
  const bool isSignedInt = fmt == py::format_descriptor<int32_t>::format() ||
      fmt == py::format_descriptor<int>::format() ||
      fmt == py::format_descriptor<int64_t>::format() ||
      fmt == "l"; // C long (size varies by platform)

  // Check for unsigned integer types
  const bool isUnsignedInt = fmt == py::format_descriptor<uint32_t>::format() ||
      fmt == py::format_descriptor<uint64_t>::format() ||
      fmt == "L"; // C unsigned long (size varies by platform)

  if (isSignedInt) {
    // itemsize, not the format character, decides 32- vs 64-bit (see above).
    if (itemsize == 4) {
      dtype_ = SourceDtype::Int32;
    } else if (itemsize == 8) {
      dtype_ = SourceDtype::Int64;
    } else {
      MT_THROW(
          "IntVectorArrayAccessor: unexpected itemsize {} for signed integer format '{}'",
          itemsize,
          fmt);
    }
  } else if (isUnsignedInt) {
    if (itemsize == 4) {
      dtype_ = SourceDtype::UInt32;
    } else if (itemsize == 8) {
      dtype_ = SourceDtype::UInt64;
    } else {
      MT_THROW(
          "IntVectorArrayAccessor: unexpected itemsize {} for unsigned integer format '{}'",
          itemsize,
          fmt);
    }
  } else {
    MT_THROW(
        "IntVectorArrayAccessor: expected integer dtype (int32, int64, uint32, or uint64), got format '{}'",
        fmt);
  }

  // Store byte strides (not element strides) - we'll convert in get()
  const auto totalNDim = bufferInfo.ndim;
  byteStrides_.resize(totalNDim);
  for (int i = 0; i < totalNDim; ++i) {
    byteStrides_[i] = bufferInfo.strides[i];
  }

  // Cache the row and column byte strides for the trailing dimensions
  rowByteStride_ = byteStrides_[leadingNDim_];
  colByteStride_ = byteStrides_[leadingNDim_ + 1];
}
462+
463+ template <int Dim>
464+ py::ssize_t IntVectorArrayAccessor<Dim>::computeByteOffset(
465+ const std::vector<py::ssize_t >& batchIndices) const {
466+ py::ssize_t offset = 0 ;
467+
468+ // Apply byte strides for each dimension
469+ // Broadcasting is automatically handled: if stride is 0, the index doesn't matter
470+ for (size_t i = 0 ; i < batchIndices.size (); ++i) {
471+ offset += batchIndices[i] * byteStrides_[i];
472+ }
473+
474+ return offset;
475+ }
476+
477+ template <int Dim>
478+ typename IntVectorArrayAccessor<Dim>::ElementView IntVectorArrayAccessor<Dim>::view(
479+ const std::vector<py::ssize_t >& batchIndices) const {
480+ const auto byteOffset = computeByteOffset (batchIndices);
481+ const auto * offsetData = static_cast <const char *>(data_) + byteOffset;
482+ return ElementView (offsetData, nElements_, rowByteStride_, colByteStride_, dtype_);
483+ }
484+
485+ template <int Dim>
486+ typename IntVectorArrayAccessor<Dim>::VectorType IntVectorArrayAccessor<Dim>::ElementView::get(
487+ py::ssize_t index) const {
488+ VectorType result;
489+
490+ switch (dtype_) {
491+ case SourceDtype::Int32: {
492+ const auto elemStride = rowStride_ / static_cast <py::ssize_t >(sizeof (int32_t ));
493+ const auto compStride = colStride_ / static_cast <py::ssize_t >(sizeof (int32_t ));
494+ const auto * ptr = static_cast <const int32_t *>(data_);
495+ for (int d = 0 ; d < Dim; ++d) {
496+ result[d] = static_cast <int >(ptr[index * elemStride + d * compStride]);
497+ }
498+ break ;
499+ }
500+ case SourceDtype::Int64: {
501+ const auto elemStride = rowStride_ / static_cast <py::ssize_t >(sizeof (int64_t ));
502+ const auto compStride = colStride_ / static_cast <py::ssize_t >(sizeof (int64_t ));
503+ const auto * ptr = static_cast <const int64_t *>(data_);
504+ for (int d = 0 ; d < Dim; ++d) {
505+ result[d] = static_cast <int >(ptr[index * elemStride + d * compStride]);
506+ }
507+ break ;
508+ }
509+ case SourceDtype::UInt32: {
510+ const auto elemStride = rowStride_ / static_cast <py::ssize_t >(sizeof (uint32_t ));
511+ const auto compStride = colStride_ / static_cast <py::ssize_t >(sizeof (uint32_t ));
512+ const auto * ptr = static_cast <const uint32_t *>(data_);
513+ for (int d = 0 ; d < Dim; ++d) {
514+ result[d] = static_cast <int >(ptr[index * elemStride + d * compStride]);
515+ }
516+ break ;
517+ }
518+ case SourceDtype::UInt64: {
519+ const auto elemStride = rowStride_ / static_cast <py::ssize_t >(sizeof (uint64_t ));
520+ const auto compStride = colStride_ / static_cast <py::ssize_t >(sizeof (uint64_t ));
521+ const auto * ptr = static_cast <const uint64_t *>(data_);
522+ for (int d = 0 ; d < Dim; ++d) {
523+ result[d] = static_cast <int >(ptr[index * elemStride + d * compStride]);
524+ }
525+ break ;
526+ }
527+ }
528+
529+ return result;
530+ }
531+
532+ template <int Dim>
533+ std::pair<int , int > IntVectorArrayAccessor<Dim>::ElementView::minmax() const {
534+ if (nElements_ == 0 ) {
535+ return {INT_MAX, INT_MIN};
536+ }
537+
538+ int minVal = INT_MAX;
539+ int maxVal = INT_MIN;
540+
541+ // Helper lambda to compute min/max for a given source type
542+ auto computeMinMax = [&]<typename SourceT>() {
543+ const auto elemStride = rowStride_ / static_cast <py::ssize_t >(sizeof (SourceT));
544+ const auto compStride = colStride_ / static_cast <py::ssize_t >(sizeof (SourceT));
545+ const auto * ptr = static_cast <const SourceT*>(data_);
546+
547+ for (py::ssize_t i = 0 ; i < nElements_; ++i) {
548+ for (int d = 0 ; d < Dim; ++d) {
549+ const int val = static_cast <int >(ptr[i * elemStride + d * compStride]);
550+ minVal = std::min (minVal, val);
551+ maxVal = std::max (maxVal, val);
552+ }
553+ }
554+ };
555+
556+ switch (dtype_) {
557+ case SourceDtype::Int32:
558+ computeMinMax.template operator ()<int32_t >();
559+ break ;
560+ case SourceDtype::Int64:
561+ computeMinMax.template operator ()<int64_t >();
562+ break ;
563+ case SourceDtype::UInt32:
564+ computeMinMax.template operator ()<uint32_t >();
565+ break ;
566+ case SourceDtype::UInt64:
567+ computeMinMax.template operator ()<uint64_t >();
568+ break ;
569+ }
570+
571+ return {minVal, maxVal};
572+ }
573+
// Explicit template instantiation for IntVectorArrayAccessor (3-component
// integer vectors, e.g. triangle indices).
template class IntVectorArrayAccessor<3>;
379576
380577} // namespace pymomentum
0 commit comments