Skip to content

Commit 7a1d52d

Browse files
committed
Add VertexConstraintErrorFunctionT for vertex constraints
Summary: Adds VertexConstraintErrorFunctionT and its leaf classes (VertexPosition, VertexPlane, VertexProjection, VertexNormal) as new vertex-based error functions that inherit from SkeletonErrorFunctionT directly instead of through GeneralErrorFunctionT. Key changes: - VertexConstraintErrorFunctionT inherits SkeletonErrorFunctionT directly - Uses SkeletonDerivativeT for vertex derivative computation (skinning weights, blend shapes) - Supports parallel threading via dispenso - L2 fast-path optimization - VertexNormalConstraintErrorFunctionT handles normal rotation correction via SkeletonDerivativeT - Includes comprehensive vertex constraint leaf tests (48 new tests) Differential Revision: D94558395
1 parent 9831eb7 commit 7a1d52d

18 files changed

+3416
-329
lines changed

cmake/build_variables.bzl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,12 @@ character_solver_public_headers = [
266266
"character_solver/state_error_function.h",
267267
"character_solver/transform_pose.h",
268268
"character_solver/trust_region_qr.h",
269+
"character_solver/vertex_constraint_error_function.h",
269270
"character_solver/vertex_error_function.h",
271+
"character_solver/vertex_normal_constraint_error_function.h",
272+
"character_solver/vertex_plane_constraint_error_function.h",
273+
"character_solver/vertex_position_constraint_error_function.h",
274+
"character_solver/vertex_projection_constraint_error_function.h",
270275
"character_solver/vertex_projection_error_function.h",
271276
"character_solver/vertex_sdf_error_function.h",
272277
"character_solver/vertex_vertex_distance_error_function.h",
@@ -302,7 +307,12 @@ character_solver_sources = [
302307
"character_solver/state_error_function.cpp",
303308
"character_solver/transform_pose.cpp",
304309
"character_solver/trust_region_qr.cpp",
310+
"character_solver/vertex_constraint_error_function.cpp",
305311
"character_solver/vertex_error_function.cpp",
312+
"character_solver/vertex_normal_constraint_error_function.cpp",
313+
"character_solver/vertex_plane_constraint_error_function.cpp",
314+
"character_solver/vertex_position_constraint_error_function.cpp",
315+
"character_solver/vertex_projection_constraint_error_function.cpp",
306316
"character_solver/vertex_projection_error_function.cpp",
307317
"character_solver/vertex_sdf_error_function.cpp",
308318
"character_solver/vertex_vertex_distance_error_function.cpp",
@@ -315,6 +325,7 @@ character_solver_test_sources = [
315325
"test/character_solver/inverse_kinematics_test.cpp",
316326
"test/character_solver/skeleton_derivative_test.cpp",
317327
"test/character_solver/solver_test.cpp",
328+
"test/character_solver/vertex_constraint_leaves_test.cpp",
318329
]
319330

320331
simd_constraints_public_headers = [
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include "momentum/character_solver/vertex_constraint_error_function.h"
9+
10+
#include "momentum/character/character.h"
11+
#include "momentum/character/mesh_state.h"
12+
#include "momentum/character/skeleton.h"
13+
#include "momentum/common/checks.h"
14+
15+
#include <dispenso/parallel_for.h>
16+
17+
#include <numeric>
18+
19+
namespace momentum {
20+
21+
template <typename T, class Data, size_t FuncDim>
22+
VertexConstraintErrorFunctionT<T, Data, FuncDim>::VertexConstraintErrorFunctionT(
23+
const Character& character,
24+
const ParameterTransform& parameterTransform,
25+
const T& lossAlpha,
26+
const T& lossC)
27+
: SkeletonErrorFunctionT<T>(character.skeleton, parameterTransform),
28+
character_(character),
29+
loss_(lossAlpha, lossC) {}
30+
31+
template <typename T, class Data, size_t FuncDim>
32+
const Character* VertexConstraintErrorFunctionT<T, Data, FuncDim>::getCharacter() const {
33+
return &character_;
34+
}
35+
36+
template <typename T, class Data, size_t FuncDim>
37+
bool VertexConstraintErrorFunctionT<T, Data, FuncDim>::needsMesh() const {
38+
return true;
39+
}
40+
41+
template <typename T, class Data, size_t FuncDim>
42+
void VertexConstraintErrorFunctionT<T, Data, FuncDim>::addConstraint(const Data& constraint) {
43+
MT_CHECK(constraint.vertexIndex != kInvalidIndex, "Constraint must have a valid vertex index");
44+
constraints_.push_back(constraint);
45+
}
46+
47+
template <typename T, class Data, size_t FuncDim>
48+
void VertexConstraintErrorFunctionT<T, Data, FuncDim>::setConstraints(
49+
std::span<const Data> constraints) {
50+
constraints_.assign(constraints.begin(), constraints.end());
51+
}
52+
53+
template <typename T, class Data, size_t FuncDim>
54+
void VertexConstraintErrorFunctionT<T, Data, FuncDim>::clearConstraints() {
55+
constraints_.clear();
56+
}
57+
58+
template <typename T, class Data, size_t FuncDim>
59+
const std::vector<Data>& VertexConstraintErrorFunctionT<T, Data, FuncDim>::getConstraints() const {
60+
return constraints_;
61+
}
62+
63+
template <typename T, class Data, size_t FuncDim>
64+
double VertexConstraintErrorFunctionT<T, Data, FuncDim>::getError(
65+
const ModelParametersT<T>& /*params*/,
66+
const SkeletonStateT<T>& state,
67+
const MeshStateT<T>& meshState) {
68+
const size_t numConstraints = constraints_.size();
69+
if (numConstraints == 0) {
70+
return 0.0;
71+
}
72+
73+
double error = 0.0;
74+
for (size_t i = 0; i < numConstraints; ++i) {
75+
const T constrWeight = static_cast<T>(constraints_[i].weight);
76+
if (constrWeight == T(0)) {
77+
continue;
78+
}
79+
80+
const size_t vertexIndex = constraints_[i].vertexIndex;
81+
const Eigen::Vector3<T> worldVec =
82+
meshState.posedMesh_->vertices[vertexIndex].template cast<T>();
83+
84+
FuncType f;
85+
std::span<DfdvType> emptyDfdv;
86+
evalFunction(
87+
i, state, meshState, std::span<const Eigen::Vector3<T>>(&worldVec, 1), f, emptyDfdv);
88+
89+
const T sqrError = f.squaredNorm();
90+
error += constrWeight * loss_.value(sqrError);
91+
}
92+
return this->weight_ * legacyWeight_ * error;
93+
}
94+
95+
template <typename T, class Data, size_t FuncDim>
96+
double VertexConstraintErrorFunctionT<T, Data, FuncDim>::getGradient(
97+
const ModelParametersT<T>& /*params*/,
98+
const SkeletonStateT<T>& state,
99+
const MeshStateT<T>& meshState,
100+
Eigen::Ref<Eigen::VectorX<T>> gradient) {
101+
const size_t numConstraints = constraints_.size();
102+
if (numConstraints == 0) {
103+
return 0.0;
104+
}
105+
106+
auto dispensoOptions = dispenso::ParForOptions();
107+
dispensoOptions.maxThreads = static_cast<uint32_t>(maxThreads_);
108+
109+
const SkeletonDerivativeT<T> skeletonDerivative(
110+
this->skeleton_,
111+
this->parameterTransform_,
112+
this->activeJointParams_,
113+
this->enabledParameters_);
114+
115+
const bool isL2 = loss_.isL2();
116+
const T lossInvC2 = loss_.invC2();
117+
118+
std::vector<std::tuple<double, VectorX<T>>> errorGradThread;
119+
dispenso::parallel_for(
120+
errorGradThread,
121+
[&]() -> std::tuple<double, VectorX<T>> {
122+
return {0.0, VectorX<T>::Zero(this->parameterTransform_.numAllModelParameters())};
123+
},
124+
size_t(0),
125+
numConstraints,
126+
[&](std::tuple<double, VectorX<T>>& errorGradLocal, const size_t i) {
127+
double& errorLocal = std::get<0>(errorGradLocal);
128+
auto& gradLocal = std::get<1>(errorGradLocal);
129+
130+
const T constrWeight = static_cast<T>(constraints_[i].weight);
131+
if (constrWeight == T(0)) {
132+
return;
133+
}
134+
135+
const size_t vertexIndex = constraints_[i].vertexIndex;
136+
const Eigen::Vector3<T> worldVec =
137+
meshState.posedMesh_->vertices[vertexIndex].template cast<T>();
138+
139+
FuncType f;
140+
DfdvType dfdv;
141+
dfdv.setZero();
142+
143+
evalFunction(
144+
i,
145+
state,
146+
meshState,
147+
std::span<const Eigen::Vector3<T>>(&worldVec, 1),
148+
f,
149+
std::span<DfdvType>(&dfdv, 1));
150+
151+
if (f.isZero()) {
152+
return;
153+
}
154+
155+
const T sqrError = f.squaredNorm();
156+
const T w = constrWeight * this->weight_ * legacyWeight_;
157+
158+
FuncType weightedResidual;
159+
if (isL2) {
160+
errorLocal += w * sqrError * lossInvC2;
161+
weightedResidual = T(2) * w * lossInvC2 * f;
162+
} else {
163+
errorLocal += w * loss_.value(sqrError);
164+
weightedResidual = T(2) * w * loss_.deriv(sqrError) * f;
165+
}
166+
167+
skeletonDerivative.template accumulateVertexGradient<FuncDim>(
168+
vertexIndex, worldVec, dfdv, weightedResidual, state, meshState, character_, gradLocal);
169+
},
170+
dispensoOptions);
171+
172+
if (!errorGradThread.empty()) {
173+
errorGradThread[0] = std::accumulate(
174+
errorGradThread.begin() + 1,
175+
errorGradThread.end(),
176+
errorGradThread[0],
177+
[](const auto& a, const auto& b) -> std::tuple<double, VectorX<T>> {
178+
return {std::get<0>(a) + std::get<0>(b), std::get<1>(a) + std::get<1>(b)};
179+
});
180+
181+
gradient += std::get<1>(errorGradThread[0]);
182+
return std::get<0>(errorGradThread[0]);
183+
}
184+
185+
return 0.0;
186+
}
187+
188+
template <typename T, class Data, size_t FuncDim>
189+
double VertexConstraintErrorFunctionT<T, Data, FuncDim>::getJacobian(
190+
const ModelParametersT<T>& /*params*/,
191+
const SkeletonStateT<T>& state,
192+
const MeshStateT<T>& meshState,
193+
Eigen::Ref<Eigen::MatrixX<T>> jacobian,
194+
Eigen::Ref<Eigen::VectorX<T>> residual,
195+
int& usedRows) {
196+
const size_t numConstraints = constraints_.size();
197+
usedRows = 0;
198+
if (numConstraints == 0) {
199+
return 0.0;
200+
}
201+
202+
auto dispensoOptions = dispenso::ParForOptions();
203+
dispensoOptions.maxThreads = static_cast<uint32_t>(maxThreads_);
204+
205+
const SkeletonDerivativeT<T> skeletonDerivative(
206+
this->skeleton_,
207+
this->parameterTransform_,
208+
this->activeJointParams_,
209+
this->enabledParameters_);
210+
211+
const bool isL2Jac = loss_.isL2();
212+
const T lossInvC2Jac = loss_.invC2();
213+
214+
std::vector<double> errorThread;
215+
dispenso::parallel_for(
216+
errorThread,
217+
[&]() -> double { return 0.0; },
218+
size_t(0),
219+
numConstraints,
220+
[&](double& errorLocal, const size_t i) {
221+
const T constrWeight = static_cast<T>(constraints_[i].weight);
222+
if (constrWeight == T(0)) {
223+
return;
224+
}
225+
226+
const size_t vertexIndex = constraints_[i].vertexIndex;
227+
const Eigen::Vector3<T> worldVec =
228+
meshState.posedMesh_->vertices[vertexIndex].template cast<T>();
229+
230+
const size_t rowIndex = i * FuncDim;
231+
FuncType f;
232+
DfdvType dfdv;
233+
dfdv.setZero();
234+
235+
evalFunction(
236+
i,
237+
state,
238+
meshState,
239+
std::span<const Eigen::Vector3<T>>(&worldVec, 1),
240+
f,
241+
std::span<DfdvType>(&dfdv, 1));
242+
243+
const T sqrError = f.squaredNorm();
244+
const T w = constrWeight * this->weight_ * legacyWeight_;
245+
246+
T deriv;
247+
if (isL2Jac) {
248+
errorLocal += w * sqrError * lossInvC2Jac;
249+
deriv = std::sqrt(w * lossInvC2Jac);
250+
} else {
251+
errorLocal += w * loss_.value(sqrError);
252+
deriv = std::sqrt(w * loss_.deriv(sqrError));
253+
}
254+
255+
residual.template segment<FuncDim>(rowIndex).noalias() = deriv * f;
256+
257+
if (deriv == T(0)) {
258+
return;
259+
}
260+
261+
skeletonDerivative.template accumulateVertexJacobian<FuncDim>(
262+
vertexIndex, worldVec, dfdv, deriv, state, meshState, character_, jacobian, rowIndex);
263+
},
264+
dispensoOptions);
265+
266+
usedRows = static_cast<int>(numConstraints * FuncDim);
267+
268+
if (!errorThread.empty()) {
269+
return std::accumulate(errorThread.begin() + 1, errorThread.end(), errorThread[0]);
270+
}
271+
272+
return 0.0;
273+
}
274+
275+
// Explicit instantiations for common constraint types
276+
template class VertexConstraintErrorFunctionT<float, VertexConstraintData, 1>;
277+
template class VertexConstraintErrorFunctionT<float, VertexConstraintData, 2>;
278+
template class VertexConstraintErrorFunctionT<float, VertexConstraintData, 3>;
279+
template class VertexConstraintErrorFunctionT<double, VertexConstraintData, 1>;
280+
template class VertexConstraintErrorFunctionT<double, VertexConstraintData, 2>;
281+
template class VertexConstraintErrorFunctionT<double, VertexConstraintData, 3>;
282+
283+
} // namespace momentum
284+
285+
// Include leaf class headers for explicit instantiations
286+
#include "momentum/character_solver/vertex_normal_constraint_error_function.h"
287+
#include "momentum/character_solver/vertex_plane_constraint_error_function.h"
288+
#include "momentum/character_solver/vertex_position_constraint_error_function.h"
289+
#include "momentum/character_solver/vertex_projection_constraint_error_function.h"
290+
291+
namespace momentum {
292+
293+
template class VertexConstraintErrorFunctionT<float, VertexPositionConstraintDataT<float>, 3>;
294+
template class VertexConstraintErrorFunctionT<double, VertexPositionConstraintDataT<double>, 3>;
295+
296+
template class VertexConstraintErrorFunctionT<float, VertexNormalConstraintDataT<float>, 1>;
297+
template class VertexConstraintErrorFunctionT<double, VertexNormalConstraintDataT<double>, 1>;
298+
299+
template class VertexConstraintErrorFunctionT<float, VertexPlaneConstraintDataT<float>, 1>;
300+
template class VertexConstraintErrorFunctionT<double, VertexPlaneConstraintDataT<double>, 1>;
301+
302+
template class VertexConstraintErrorFunctionT<float, VertexProjectionConstraintDataT<float>, 2>;
303+
template class VertexConstraintErrorFunctionT<double, VertexProjectionConstraintDataT<double>, 2>;
304+
305+
} // namespace momentum

0 commit comments

Comments
 (0)