Skip to content

Commit fed60bd

Browse files
cdtwigg authored and meta-codesync[bot] committed
Migrate compute_vertex_normals to numpy arrays (#961)
Summary: Pull Request resolved: #961 Migrates `compute_vertex_normals` from torch tensors to numpy arrays as part of the ongoing effort to remove PyTorch dependencies from pymomentum.geometry. Key changes: - Refactored `VertexPositionsAccessor` to generic `VectorArrayAccessor<T, Dim>` with an inner `ElementView` class for per-element access - `ElementView` provides `get()`, `set()`, `add()`, `setZero()`, and `normalize()` methods for efficient element-level operations - Added `VectorArrayAccessor<int32_t, 3>` instantiation for triangle index access - Implemented `computeVertexNormalsArray()` using the new accessor pattern - Added bounds checking for triangle indices to provide clear error messages instead of undefined behavior - Updated test_geometry.py to use numpy arrays instead of torch tensors Reviewed By: jeongseok-meta Differential Revision: D89693867 fbshipit-source-id: c25e2d0a5eb87c6cd5a1b4a806fa79e5bfd09f92
1 parent 586d661 commit fed60bd

File tree

7 files changed

+683
-68
lines changed

7 files changed

+683
-68
lines changed

pymomentum/array_utility/geometry_accessors.cpp

Lines changed: 236 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -296,15 +296,15 @@ template class SkeletonStateAccessor<float>;
296296
template class SkeletonStateAccessor<double>;
297297

298298
// ============================================================================
299-
// VertexPositionsAccessor implementation
299+
// VectorArrayAccessor implementation
300300
// ============================================================================
301301

302-
template <typename T>
303-
VertexPositionsAccessor<T>::VertexPositionsAccessor(
302+
template <typename T, int Dim>
303+
VectorArrayAccessor<T, Dim>::VectorArrayAccessor(
304304
const py::buffer_info& bufferInfo,
305305
const LeadingDimensions& leadingDims,
306-
py::ssize_t nVertices)
307-
: nVertices_(nVertices), leadingNDim_(leadingDims.ndim()) {
306+
py::ssize_t nElements)
307+
: nElements_(nElements), leadingNDim_(leadingDims.ndim()) {
308308
data_ = static_cast<T*>(bufferInfo.ptr);
309309

310310
// Extract strides (convert from bytes to elements)
@@ -313,10 +313,14 @@ VertexPositionsAccessor<T>::VertexPositionsAccessor(
313313
for (int i = 0; i < totalNDim; ++i) {
314314
strides_[i] = static_cast<py::ssize_t>(bufferInfo.strides[i] / sizeof(T));
315315
}
316+
317+
// Cache the row and column strides for the trailing dimensions
318+
rowStride_ = strides_[leadingNDim_]; // stride for element dimension
319+
colStride_ = strides_[leadingNDim_ + 1]; // stride for vector component dimension
316320
}
317321

318-
template <typename T>
319-
py::ssize_t VertexPositionsAccessor<T>::computeOffset(
322+
template <typename T, int Dim>
323+
py::ssize_t VectorArrayAccessor<T, Dim>::computeOffset(
320324
const std::vector<py::ssize_t>& batchIndices) const {
321325
py::ssize_t offset = 0;
322326

@@ -329,52 +333,245 @@ py::ssize_t VertexPositionsAccessor<T>::computeOffset(
329333
return offset;
330334
}
331335

332-
template <typename T>
333-
std::vector<Eigen::Vector3<T>> VertexPositionsAccessor<T>::get(
334-
const std::vector<py::ssize_t>& batchIndices) const {
336+
template <typename T, int Dim>
337+
typename VectorArrayAccessor<T, Dim>::ElementView VectorArrayAccessor<T, Dim>::view(
338+
const std::vector<py::ssize_t>& batchIndices) {
335339
const auto offset = computeOffset(batchIndices);
340+
return ElementView(data_ + offset, nElements_, rowStride_, colStride_);
341+
}
336342

337-
// Create a vector of Vector3<T> using strides
338-
// Vertex positions format: (..., nVertices, 3) where each vertex has [x, y, z]
339-
std::vector<Eigen::Vector3<T>> positions(nVertices_);
343+
template <typename T, int Dim>
344+
typename VectorArrayAccessor<T, Dim>::ElementView VectorArrayAccessor<T, Dim>::view(
345+
const std::vector<py::ssize_t>& batchIndices) const {
346+
const auto offset = computeOffset(batchIndices);
347+
// const_cast is safe here because we control access at the call site
348+
return ElementView(const_cast<T*>(data_) + offset, nElements_, rowStride_, colStride_);
349+
}
340350

341-
const auto rowStride = strides_[leadingNDim_]; // stride for vertex dimension
342-
const auto colStride = strides_[leadingNDim_ + 1]; // stride for xyz dimension
351+
template <typename T, int Dim>
352+
std::vector<typename VectorArrayAccessor<T, Dim>::VectorType> VectorArrayAccessor<T, Dim>::get(
353+
const std::vector<py::ssize_t>& batchIndices) const {
354+
const auto offset = computeOffset(batchIndices);
343355

344-
for (py::ssize_t iVert = 0; iVert < nVertices_; ++iVert) {
345-
const auto vertOffset = offset + iVert * rowStride;
346-
positions[iVert].x() = data_[vertOffset + 0 * colStride];
347-
positions[iVert].y() = data_[vertOffset + 1 * colStride];
348-
positions[iVert].z() = data_[vertOffset + 2 * colStride];
356+
std::vector<VectorType> result(nElements_);
357+
for (py::ssize_t i = 0; i < nElements_; ++i) {
358+
const auto elemOffset = offset + i * rowStride_;
359+
for (int d = 0; d < Dim; ++d) {
360+
result[i][d] = data_[elemOffset + d * colStride_];
361+
}
349362
}
350363

351-
return positions;
364+
return result;
352365
}
353366

354-
template <typename T>
355-
void VertexPositionsAccessor<T>::set(
367+
template <typename T, int Dim>
368+
void VectorArrayAccessor<T, Dim>::set(
356369
const std::vector<py::ssize_t>& batchIndices,
357-
const std::vector<Eigen::Vector3<T>>& positions) {
370+
const std::vector<VectorType>& values) {
358371
MT_THROW_IF(
359-
static_cast<py::ssize_t>(positions.size()) != nVertices_,
360-
"set: expected {} vertices but got {}",
361-
nVertices_,
362-
positions.size());
372+
static_cast<py::ssize_t>(values.size()) != nElements_,
373+
"set: expected {} elements but got {}",
374+
nElements_,
375+
values.size());
363376

364377
const auto offset = computeOffset(batchIndices);
365-
const auto rowStride = strides_[leadingNDim_]; // stride for vertex dimension
366-
const auto colStride = strides_[leadingNDim_ + 1]; // stride for xyz dimension
367-
368-
for (py::ssize_t iVert = 0; iVert < nVertices_; ++iVert) {
369-
const auto vertOffset = offset + iVert * rowStride;
370-
data_[vertOffset + 0 * colStride] = positions[iVert].x();
371-
data_[vertOffset + 1 * colStride] = positions[iVert].y();
372-
data_[vertOffset + 2 * colStride] = positions[iVert].z();
378+
for (py::ssize_t i = 0; i < nElements_; ++i) {
379+
const auto elemOffset = offset + i * rowStride_;
380+
for (int d = 0; d < Dim; ++d) {
381+
data_[elemOffset + d * colStride_] = values[i][d];
382+
}
373383
}
374384
}
375385

376-
// Explicit template instantiations
377-
template class VertexPositionsAccessor<float>;
378-
template class VertexPositionsAccessor<double>;
386+
// Explicit template instantiations for VectorArrayAccessor
387+
template class VectorArrayAccessor<float, 3>;
388+
template class VectorArrayAccessor<double, 3>;
389+
template class VectorArrayAccessor<int, 3>;
390+
391+
// ============================================================================
392+
// IntVectorArrayAccessor implementation
393+
// ============================================================================
394+
395+
template <int Dim>
396+
IntVectorArrayAccessor<Dim>::IntVectorArrayAccessor(
397+
const py::buffer& buffer,
398+
const LeadingDimensions& leadingDims,
399+
py::ssize_t nElements)
400+
: nElements_(nElements), leadingNDim_(leadingDims.ndim()) {
401+
auto bufferInfo = buffer.request();
402+
data_ = bufferInfo.ptr;
403+
404+
// Detect the source dtype based on format code and itemsize.
405+
// Format codes 'l' (long) and 'L' (unsigned long) are platform-dependent:
406+
// - LP64 (Linux/macOS 64-bit): long is 64-bit
407+
// - LLP64 (Windows 64-bit): long is 32-bit
408+
// We use itemsize to disambiguate these cases.
409+
const auto& fmt = bufferInfo.format;
410+
const auto itemsize = bufferInfo.itemsize;
411+
412+
// Check for signed integer types
413+
const bool isSignedInt = fmt == py::format_descriptor<int32_t>::format() ||
414+
fmt == py::format_descriptor<int>::format() ||
415+
fmt == py::format_descriptor<int64_t>::format() ||
416+
fmt == "l"; // C long (size varies by platform)
417+
418+
// Check for unsigned integer types
419+
const bool isUnsignedInt = fmt == py::format_descriptor<uint32_t>::format() ||
420+
fmt == py::format_descriptor<uint64_t>::format() ||
421+
fmt == "L"; // C unsigned long (size varies by platform)
422+
423+
if (isSignedInt) {
424+
if (itemsize == 4) {
425+
dtype_ = SourceDtype::Int32;
426+
} else if (itemsize == 8) {
427+
dtype_ = SourceDtype::Int64;
428+
} else {
429+
MT_THROW(
430+
"IntVectorArrayAccessor: unexpected itemsize {} for signed integer format '{}'",
431+
itemsize,
432+
fmt);
433+
}
434+
} else if (isUnsignedInt) {
435+
if (itemsize == 4) {
436+
dtype_ = SourceDtype::UInt32;
437+
} else if (itemsize == 8) {
438+
dtype_ = SourceDtype::UInt64;
439+
} else {
440+
MT_THROW(
441+
"IntVectorArrayAccessor: unexpected itemsize {} for unsigned integer format '{}'",
442+
itemsize,
443+
fmt);
444+
}
445+
} else {
446+
MT_THROW(
447+
"IntVectorArrayAccessor: expected integer dtype (int32, int64, uint32, or uint64), got format '{}'",
448+
fmt);
449+
}
450+
451+
// Store byte strides (not element strides) - we'll convert in get()
452+
const auto totalNDim = bufferInfo.ndim;
453+
byteStrides_.resize(totalNDim);
454+
for (int i = 0; i < totalNDim; ++i) {
455+
byteStrides_[i] = bufferInfo.strides[i];
456+
}
457+
458+
// Cache the row and column byte strides for the trailing dimensions
459+
rowByteStride_ = byteStrides_[leadingNDim_];
460+
colByteStride_ = byteStrides_[leadingNDim_ + 1];
461+
}
462+
463+
template <int Dim>
464+
py::ssize_t IntVectorArrayAccessor<Dim>::computeByteOffset(
465+
const std::vector<py::ssize_t>& batchIndices) const {
466+
py::ssize_t offset = 0;
467+
468+
// Apply byte strides for each dimension
469+
// Broadcasting is automatically handled: if stride is 0, the index doesn't matter
470+
for (size_t i = 0; i < batchIndices.size(); ++i) {
471+
offset += batchIndices[i] * byteStrides_[i];
472+
}
473+
474+
return offset;
475+
}
476+
477+
template <int Dim>
478+
typename IntVectorArrayAccessor<Dim>::ElementView IntVectorArrayAccessor<Dim>::view(
479+
const std::vector<py::ssize_t>& batchIndices) const {
480+
const auto byteOffset = computeByteOffset(batchIndices);
481+
const auto* offsetData = static_cast<const char*>(data_) + byteOffset;
482+
return ElementView(offsetData, nElements_, rowByteStride_, colByteStride_, dtype_);
483+
}
484+
485+
template <int Dim>
486+
typename IntVectorArrayAccessor<Dim>::VectorType IntVectorArrayAccessor<Dim>::ElementView::get(
487+
py::ssize_t index) const {
488+
VectorType result;
489+
490+
switch (dtype_) {
491+
case SourceDtype::Int32: {
492+
const auto elemStride = rowStride_ / static_cast<py::ssize_t>(sizeof(int32_t));
493+
const auto compStride = colStride_ / static_cast<py::ssize_t>(sizeof(int32_t));
494+
const auto* ptr = static_cast<const int32_t*>(data_);
495+
for (int d = 0; d < Dim; ++d) {
496+
result[d] = static_cast<int>(ptr[index * elemStride + d * compStride]);
497+
}
498+
break;
499+
}
500+
case SourceDtype::Int64: {
501+
const auto elemStride = rowStride_ / static_cast<py::ssize_t>(sizeof(int64_t));
502+
const auto compStride = colStride_ / static_cast<py::ssize_t>(sizeof(int64_t));
503+
const auto* ptr = static_cast<const int64_t*>(data_);
504+
for (int d = 0; d < Dim; ++d) {
505+
result[d] = static_cast<int>(ptr[index * elemStride + d * compStride]);
506+
}
507+
break;
508+
}
509+
case SourceDtype::UInt32: {
510+
const auto elemStride = rowStride_ / static_cast<py::ssize_t>(sizeof(uint32_t));
511+
const auto compStride = colStride_ / static_cast<py::ssize_t>(sizeof(uint32_t));
512+
const auto* ptr = static_cast<const uint32_t*>(data_);
513+
for (int d = 0; d < Dim; ++d) {
514+
result[d] = static_cast<int>(ptr[index * elemStride + d * compStride]);
515+
}
516+
break;
517+
}
518+
case SourceDtype::UInt64: {
519+
const auto elemStride = rowStride_ / static_cast<py::ssize_t>(sizeof(uint64_t));
520+
const auto compStride = colStride_ / static_cast<py::ssize_t>(sizeof(uint64_t));
521+
const auto* ptr = static_cast<const uint64_t*>(data_);
522+
for (int d = 0; d < Dim; ++d) {
523+
result[d] = static_cast<int>(ptr[index * elemStride + d * compStride]);
524+
}
525+
break;
526+
}
527+
}
528+
529+
return result;
530+
}
531+
532+
template <int Dim>
533+
std::pair<int, int> IntVectorArrayAccessor<Dim>::ElementView::minmax() const {
534+
if (nElements_ == 0) {
535+
return {INT_MAX, INT_MIN};
536+
}
537+
538+
int minVal = INT_MAX;
539+
int maxVal = INT_MIN;
540+
541+
// Helper lambda to compute min/max for a given source type
542+
auto computeMinMax = [&]<typename SourceT>() {
543+
const auto elemStride = rowStride_ / static_cast<py::ssize_t>(sizeof(SourceT));
544+
const auto compStride = colStride_ / static_cast<py::ssize_t>(sizeof(SourceT));
545+
const auto* ptr = static_cast<const SourceT*>(data_);
546+
547+
for (py::ssize_t i = 0; i < nElements_; ++i) {
548+
for (int d = 0; d < Dim; ++d) {
549+
const int val = static_cast<int>(ptr[i * elemStride + d * compStride]);
550+
minVal = std::min(minVal, val);
551+
maxVal = std::max(maxVal, val);
552+
}
553+
}
554+
};
555+
556+
switch (dtype_) {
557+
case SourceDtype::Int32:
558+
computeMinMax.template operator()<int32_t>();
559+
break;
560+
case SourceDtype::Int64:
561+
computeMinMax.template operator()<int64_t>();
562+
break;
563+
case SourceDtype::UInt32:
564+
computeMinMax.template operator()<uint32_t>();
565+
break;
566+
case SourceDtype::UInt64:
567+
computeMinMax.template operator()<uint64_t>();
568+
break;
569+
}
570+
571+
return {minVal, maxVal};
572+
}
573+
574+
// Explicit template instantiations for IntVectorArrayAccessor
575+
template class IntVectorArrayAccessor<3>;
379576

380577
} // namespace pymomentum

0 commit comments

Comments (0)