-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy patharray_skeleton_state.cpp
More file actions
311 lines (247 loc) · 11.6 KB
/
array_skeleton_state.cpp
File metadata and controls
311 lines (247 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "pymomentum/geometry/array_skeleton_state.h"
#include <pymomentum/array_utility/array_utility.h>
#include <pymomentum/array_utility/batch_accessor.h>
#include <pymomentum/array_utility/geometry_accessors.h>
#include <pymomentum/geometry/array_parameter_transform.h>
#include <momentum/character/joint.h>
#include <momentum/character/skeleton.h>
#include <momentum/character/skeleton_state.h>
#include <momentum/math/utility.h>
#include <dispenso/parallel_for.h>
#include <utility>
namespace pymomentum {
namespace {
/// Evaluates the global transform of every joint for a batch of joint-parameter
/// vectors.
///
/// @param character    Character providing the skeleton hierarchy.
/// @param jointParams  Input joint parameters, laid out according to `shape`.
/// @param leadingDims  Batch dimensions preceding the per-joint dimensions.
/// @param shape        Layout of `jointParams` as reported by validation.
/// @return New array shaped [..., nJoints, 8] holding, per joint, the
///         `transform` field of the computed momentum::SkeletonStateT.
template <typename T>
py::array_t<T> jointParametersToSkeletonStateImpl(
    const momentum::Character& character,
    const py::array& jointParams,
    const LeadingDimensions& leadingDims,
    JointParamsShape shape) {
  const auto& skeleton = character.skeleton;
  const auto jointCount = static_cast<py::ssize_t>(skeleton.joints.size());
  const auto batchCount = leadingDims.totalBatchElements();

  auto output = createOutputArray<T>(leadingDims, {jointCount, 8});
  BatchIndexer batchIndexer(leadingDims);
  JointParametersAccessor<T> paramReader(jointParams, leadingDims, jointCount, shape);
  SkeletonStateAccessor<T> stateWriter(output, leadingDims, jointCount);

  {
    // The GIL is released while the batch is processed in parallel and is
    // reacquired when `noGil` goes out of scope.
    py::gil_scoped_release noGil;
    dispenso::parallel_for(0, static_cast<int64_t>(batchCount), [&](int64_t flatIdx) {
      auto batchIdx = batchIndexer.decompose(flatIdx);
      auto params = paramReader.get(batchIdx);

      // Forward kinematics without derivatives; only the transforms are needed.
      const momentum::SkeletonStateT<T> state(params, skeleton, /*computeDeriv=*/false);

      momentum::TransformListT<T> globalXfs(static_cast<size_t>(jointCount));
      for (size_t j = 0; j < globalXfs.size(); ++j) {
        globalXfs[j] = state.jointState[j].transform;
      }
      stateWriter.setTransforms(batchIdx, globalXfs);
    });
  }
  return output;
}
/// Evaluates the local (parent-relative) transform of every joint for a batch of
/// joint-parameter vectors.
///
/// @param character    Character providing the skeleton hierarchy.
/// @param jointParams  Input joint parameters, laid out according to `shape`.
/// @param leadingDims  Batch dimensions preceding the per-joint dimensions.
/// @param shape        Layout of `jointParams` as reported by validation.
/// @return New array shaped [..., nJoints, 8] holding, per joint, the
///         `localTransform` field of the computed momentum::SkeletonStateT.
template <typename T>
py::array_t<T> jointParametersToLocalSkeletonStateImpl(
    const momentum::Character& character,
    const py::array& jointParams,
    const LeadingDimensions& leadingDims,
    JointParamsShape shape) {
  const auto& skeleton = character.skeleton;
  const auto jointCount = static_cast<py::ssize_t>(skeleton.joints.size());
  const auto batchCount = leadingDims.totalBatchElements();

  auto output = createOutputArray<T>(leadingDims, {jointCount, 8});
  BatchIndexer batchIndexer(leadingDims);
  JointParametersAccessor<T> paramReader(jointParams, leadingDims, jointCount, shape);
  SkeletonStateAccessor<T> stateWriter(output, leadingDims, jointCount);

  {
    // Drop the GIL for the duration of the parallel loop.
    py::gil_scoped_release noGil;
    dispenso::parallel_for(0, static_cast<int64_t>(batchCount), [&](int64_t flatIdx) {
      auto batchIdx = batchIndexer.decompose(flatIdx);
      auto params = paramReader.get(batchIdx);

      // The full skeleton state also carries each joint's local transform.
      const momentum::SkeletonStateT<T> state(params, skeleton, /*computeDeriv=*/false);

      momentum::TransformListT<T> localXfs(static_cast<size_t>(jointCount));
      for (size_t j = 0; j < localXfs.size(); ++j) {
        localXfs[j] = state.jointState[j].localTransform;
      }
      stateWriter.setTransforms(batchIdx, localXfs);
    });
  }
  return output;
}
/// Recovers structured joint parameters from a batch of global skeleton states.
///
/// The per-sample inversion is delegated to momentum::skeletonStateToJointParameters.
///
/// @param character     Character providing the skeleton hierarchy.
/// @param skeletonState Global skeleton state array (validated by the caller).
/// @param leadingDims   Batch dimensions preceding the per-joint dimensions.
/// @return New array shaped [..., nJoints, 7] in structured layout.
template <typename T>
py::array_t<T> skeletonStateToJointParametersImpl(
    const momentum::Character& character,
    const py::array& skeletonState,
    const LeadingDimensions& leadingDims) {
  const auto& skeleton = character.skeleton;
  const auto jointCount = static_cast<py::ssize_t>(skeleton.joints.size());
  const auto batchCount = leadingDims.totalBatchElements();

  auto output = createOutputArray<T>(leadingDims, {jointCount, 7});
  BatchIndexer batchIndexer(leadingDims);
  SkeletonStateAccessor<T> stateReader(skeletonState, leadingDims, jointCount);
  JointParametersAccessor<T> paramWriter(
      output, leadingDims, jointCount, JointParamsShape::Structured);

  {
    // Parallel section runs without holding the GIL.
    py::gil_scoped_release noGil;
    dispenso::parallel_for(0, static_cast<int64_t>(batchCount), [&](int64_t flatIdx) {
      auto batchIdx = batchIndexer.decompose(flatIdx);
      auto globalXfs = stateReader.getTransforms(batchIdx);
      paramWriter.set(
          batchIdx, momentum::skeletonStateToJointParameters<T>(globalXfs, skeleton));
    });
  }
  return output;
}
/// Recovers structured joint parameters from a batch of local skeleton states.
///
/// Inverts the per-joint local transform, which is assumed to be built as:
///   translation = joint.translationOffset + (tx, ty, tz)
///   rotation    = joint.preRotation * Rz(rz) * Ry(ry) * Rx(rx)
///   scale       = 2^s
///
/// @param character          Character providing the skeleton (offsets, pre-rotations).
/// @param localSkeletonState Local skeleton state array (validated by the caller).
/// @param leadingDims        Batch dimensions preceding the per-joint dimensions.
/// @return New array shaped [..., nJoints, 7] in structured layout
///         (tx, ty, tz, rx, ry, rz, s per joint).
///
/// NOTE: Euler angles are not unique. The recovered angles reproduce the
/// rotation, but they may differ from the angles that originally produced it
/// (e.g. near gimbal lock, or when angles wrap).
template <typename T>
py::array_t<T> localSkeletonStateToJointParametersImpl(
    const momentum::Character& character,
    const py::array& localSkeletonState,
    const LeadingDimensions& leadingDims) {
  // Parameters per joint: tx, ty, tz, rx, ry, rz, scale.
  constexpr py::ssize_t kParamsPerJoint = 7;

  const auto nJoints = static_cast<py::ssize_t>(character.skeleton.joints.size());
  const auto nBatch = leadingDims.totalBatchElements();

  // Output array with shape [..., nJoints, 7] (structured format).
  auto result = createOutputArray<T>(leadingDims, {nJoints, kParamsPerJoint});

  // Converts flat batch indices to multi-dimensional indices.
  BatchIndexer indexer(leadingDims);
  SkeletonStateAccessor<T> inputAcc(localSkeletonState, leadingDims, nJoints);
  JointParametersAccessor<T> outputAcc(result, leadingDims, nJoints, JointParamsShape::Structured);

  {
    // Release the GIL for the parallel computation.
    py::gil_scoped_release release;
    dispenso::parallel_for(0, static_cast<int64_t>(nBatch), [&](int64_t iBatch) {
      auto indices = indexer.decompose(iBatch);
      auto localTransforms = inputAcc.getTransforms(indices);

      Eigen::Matrix<T, Eigen::Dynamic, 1> jpVec(nJoints * kParamsPerJoint);
      for (int64_t iJoint = 0; iJoint < nJoints; ++iJoint) {
        const auto& joint = character.skeleton.joints[iJoint];
        const auto& localTrans = localTransforms[iJoint];
        const Eigen::Index base = iJoint * kParamsPerJoint;

        // Translation: localTrans.translation = joint.translationOffset + params[0:3].
        jpVec.template segment<3>(base) =
            localTrans.translation - joint.translationOffset.template cast<T>();

        // Rotation: localRot = preRotation * Rz(rz) * Ry(ry) * Rx(rx).
        // Remove the pre-rotation, then extract extrinsic-XYZ Euler angles.
        // The quaternion comes from a user-supplied array and may not be exactly
        // unit length. Eigen's toRotationMatrix() requires a normalized quaternion,
        // so normalize before converting.
        const Eigen::Quaternion<T> pureRot =
            (joint.preRotation.template cast<T>().conjugate() * localTrans.rotation).normalized();
        const Eigen::Matrix3<T> rotMat = pureRot.toRotationMatrix();
        const Eigen::Vector3<T> euler =
            momentum::rotationMatrixToEulerXYZ(rotMat, momentum::EulerConvention::Extrinsic);
        jpVec(base + 3) = euler(0); // rx
        jpVec(base + 4) = euler(1); // ry
        jpVec(base + 5) = euler(2); // rz

        // Scale: scale = 2^params[6], so params[6] = log2(scale).
        jpVec(base + 6) = std::log2(localTrans.scale);
      }

      outputAcc.set(indices, momentum::JointParametersT<T>(jpVec));
    });
  }
  return result;
}
} // namespace
/// Python-facing entry point: joint parameters -> global skeleton state.
///
/// Validates `jointParams` against `character`, then dispatches on dtype.
/// float64 input is processed in double precision; otherwise single precision
/// is used.
py::array jointParametersToSkeletonStateArray(
    const momentum::Character& character,
    const py::buffer& jointParams) {
  ArrayChecker checker("joint_parameters_to_skeleton_state");
  const JointParamsShape shape =
      checker.validateJointParameters(jointParams, "joint_parameters", character);

  if (!checker.isFloat64()) {
    return jointParametersToSkeletonStateImpl<float>(
        character, jointParams, checker.getLeadingDimensions(), shape);
  }
  return jointParametersToSkeletonStateImpl<double>(
      character, jointParams, checker.getLeadingDimensions(), shape);
}
/// Python-facing entry point: model parameters -> global skeleton state.
///
/// Maps the model parameters through the character's parameter transform,
/// then reuses the joint-parameter path.
py::array modelParametersToSkeletonStateArray(
    const momentum::Character& character,
    const py::buffer& modelParams) {
  const py::array jointParameters =
      applyParameterTransformArray(character.parameterTransform, modelParams);
  return jointParametersToSkeletonStateArray(character, jointParameters);
}
/// Python-facing entry point: joint parameters -> local skeleton state.
///
/// Validates `jointParams` against `character`, then dispatches on dtype.
/// float64 input is processed in double precision; otherwise single precision
/// is used.
py::array jointParametersToLocalSkeletonStateArray(
    const momentum::Character& character,
    const py::buffer& jointParams) {
  ArrayChecker checker("joint_parameters_to_local_skeleton_state");
  const JointParamsShape shape =
      checker.validateJointParameters(jointParams, "joint_parameters", character);

  if (!checker.isFloat64()) {
    return jointParametersToLocalSkeletonStateImpl<float>(
        character, jointParams, checker.getLeadingDimensions(), shape);
  }
  return jointParametersToLocalSkeletonStateImpl<double>(
      character, jointParams, checker.getLeadingDimensions(), shape);
}
/// Python-facing entry point: model parameters -> local skeleton state.
///
/// Maps the model parameters through the character's parameter transform,
/// then reuses the joint-parameter path.
py::array modelParametersToLocalSkeletonStateArray(
    const momentum::Character& character,
    const py::buffer& modelParams) {
  const py::array jointParameters =
      applyParameterTransformArray(character.parameterTransform, modelParams);
  return jointParametersToLocalSkeletonStateArray(character, jointParameters);
}
/// Python-facing entry point: global skeleton state -> structured joint parameters.
///
/// Validates `skeletonState` against `character`, then dispatches on dtype.
/// float64 input is processed in double precision; otherwise single precision
/// is used.
py::array skeletonStateToJointParametersArray(
    const momentum::Character& character,
    const py::buffer& skeletonState) {
  ArrayChecker checker("skeleton_state_to_joint_parameters");
  checker.validateSkeletonState(skeletonState, "skeleton_state", character);

  if (!checker.isFloat64()) {
    return skeletonStateToJointParametersImpl<float>(
        character, skeletonState, checker.getLeadingDimensions());
  }
  return skeletonStateToJointParametersImpl<double>(
      character, skeletonState, checker.getLeadingDimensions());
}
/// Python-facing entry point: local skeleton state -> structured joint parameters.
///
/// Validates `localSkeletonState` against `character`, then dispatches on dtype.
/// float64 input is processed in double precision; otherwise single precision
/// is used.
py::array localSkeletonStateToJointParametersArray(
    const momentum::Character& character,
    const py::buffer& localSkeletonState) {
  ArrayChecker checker("local_skeleton_state_to_joint_parameters");
  checker.validateSkeletonState(localSkeletonState, "local_skeleton_state", character);

  if (!checker.isFloat64()) {
    return localSkeletonStateToJointParametersImpl<float>(
        character, localSkeletonState, checker.getLeadingDimensions());
  }
  return localSkeletonStateToJointParametersImpl<double>(
      character, localSkeletonState, checker.getLeadingDimensions());
}
} // namespace pymomentum