Skip to content

Commit 7abf7cc

Browse files
cstollmetafacebook-github-bot
authored andcommitted
Split solver2_error_functions.cpp into multiple files to reduce LOC
Summary: This diff refactors `solver2_error_functions.cpp` (previously 2929 lines) by splitting error function registrations into separate files grouped by category: 1. **solver2_aim_axis_error_functions.cpp** (~365 lines): AimData, AimDistErrorFunction, AimDirErrorFunction, FixedAxisData, FixedAxisDiffErrorFunction, FixedAxisCosErrorFunction, FixedAxisAngleErrorFunction 2. **solver2_projection_error_functions.cpp** (~566 lines): ProjectionConstraint, ProjectionErrorFunction, VertexProjectionConstraint, VertexProjectionErrorFunction, CameraProjectionConstraint, CameraProjectionErrorFunction 3. **solver2_distance_error_functions.cpp** (~1051 lines): NormalData, NormalErrorFunction, DistanceData, DistanceErrorFunction, PlaneData, PlaneErrorFunction, VertexVertexDistanceConstraint, VertexVertexDistanceErrorFunction, JointToJointDistanceConstraint, JointToJointDistanceErrorFunction, JointToJointPositionData, JointToJointPositionErrorFunction 4. **solver2_error_functions.cpp** (reduced to ~1058 lines): LimitErrorFunction, HeightErrorFunction, ModelParametersErrorFunction, PositionErrorFunction, StateErrorFunction, VertexErrorFunction, PointTriangleVertexErrorFunction, PosePriorErrorFunction, OrientationErrorFunction, CollisionErrorFunction, VertexSDFErrorFunction The main file now calls sub-registration functions to maintain the same pybind11 module interface. Reviewed By: cdtwigg Differential Revision: D95088885
1 parent b9def7b commit 7abf7cc

File tree

6 files changed

+2049
-1972
lines changed

6 files changed

+2049
-1972
lines changed

pymomentum/cmake/build_variables.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,11 @@ solver2_public_headers = [
194194
]
195195

196196
solver2_sources = [
197+
"solver2/solver2_aim_axis_error_functions.cpp",
197198
"solver2/solver2_camera_intrinsics.cpp",
199+
"solver2/solver2_distance_error_functions.cpp",
198200
"solver2/solver2_error_functions.cpp",
201+
"solver2/solver2_projection_error_functions.cpp",
199202
"solver2/solver2_pybind.cpp",
200203
"solver2/solver2_sequence_error_functions.cpp",
201204
"solver2/solver2_utility.cpp",
Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include <pymomentum/solver2/solver2_error_functions.h>
9+
10+
#include <momentum/character/character.h>
11+
#include <momentum/character_solver/aim_error_function.h>
12+
#include <momentum/character_solver/fixed_axis_error_function.h>
13+
#include <pymomentum/solver2/solver2_utility.h>
14+
15+
#include <fmt/format.h>
16+
#include <pybind11/eigen.h>
17+
#include <pybind11/pybind11.h>
18+
#include <pybind11/stl.h>
19+
#include <Eigen/Core>
20+
21+
namespace py = pybind11;
22+
namespace mm = momentum;
23+
24+
namespace pymomentum {
25+
26+
namespace {
27+
28+
template <typename AimErrorFunctionT>
29+
void defAimErrorFunction(py::module_& m, const char* name, const char* description) {
30+
py::class_<AimErrorFunctionT, mm::SkeletonErrorFunction, std::shared_ptr<AimErrorFunctionT>>(
31+
m, name, description)
32+
.def(
33+
"__repr__",
34+
[=](const AimErrorFunctionT& self) -> std::string {
35+
return fmt::format(
36+
"{}(weight={}, num_constraints={})", name, self.getWeight(), self.numConstraints());
37+
})
38+
.def(
39+
py::init<>(
40+
[](const mm::Character& character, float lossAlpha, float lossC, float weight)
41+
-> std::shared_ptr<AimErrorFunctionT> {
42+
validateWeight(weight, "weight");
43+
auto result = std::make_shared<AimErrorFunctionT>(character, lossAlpha, lossC);
44+
result->setWeight(weight);
45+
return result;
46+
}),
47+
R"(Initialize error function.
48+
49+
:param character: The character to use.
50+
:param alpha: P-norm to use; 2 is Euclidean distance, 1 is L1 norm, 0 is Cauchy norm.
51+
:param c: c parameter in the generalized loss function, this is roughly equivalent to normalizing by the standard deviation in a Gaussian distribution.
52+
:param weight: The weight applied to the error function.)",
53+
py::keep_alive<1, 2>(),
54+
py::arg("character"),
55+
py::kw_only(),
56+
py::arg("alpha") = 2.0f,
57+
py::arg("c") = 1.0f,
58+
py::arg("weight") = 1.0f)
59+
.def_property_readonly(
60+
"constraints",
61+
[](const AimErrorFunctionT& self) { return self.getConstraints(); },
62+
"Returns the list of aim constraints.")
63+
.def(
64+
"add_constraint",
65+
[](AimErrorFunctionT& self,
66+
const Eigen::Vector3f& localPoint,
67+
const Eigen::Vector3f& localDir,
68+
const Eigen::Vector3f& globalTarget,
69+
int parent,
70+
float weight,
71+
const std::string& name) {
72+
validateJointIndex(parent, "parent", self.getSkeleton());
73+
validateWeight(weight, "weight");
74+
self.addConstraint(
75+
mm::AimDataT<float>(localPoint, localDir, globalTarget, parent, weight, name));
76+
},
77+
R"(Adds an aim constraint to the error function.
78+
79+
:param local_point: The origin of the local ray.
80+
:param local_dir: The direction of the local ray.
81+
:param global_target: The global aim target.
82+
:param parent_index: The index of the parent joint.
83+
:param weight: The weight of the constraint.
84+
:param name: The name of the constraint (for debugging).)",
85+
py::arg("local_point"),
86+
py::arg("local_dir"),
87+
py::arg("global_target"),
88+
py::arg("parent"),
89+
py::arg("weight") = 1.0f,
90+
py::arg("name") = std::string{})
91+
.def(
92+
"add_constraints",
93+
[](AimErrorFunctionT& self,
94+
const py::array_t<float>& localPoint,
95+
const py::array_t<float>& localDir,
96+
const py::array_t<float>& globalTarget,
97+
const py::array_t<int>& parent,
98+
const std::optional<py::array_t<float>>& weight,
99+
const std::optional<std::vector<std::string>>& name) {
100+
ArrayShapeValidator arrayValidator;
101+
const int nConsIdx = -1;
102+
arrayValidator.validate(localPoint, "local_point", {nConsIdx, 3}, {"n_cons", "xyz"});
103+
arrayValidator.validate(localDir, "local_dir", {nConsIdx, 3}, {"n_cons", "xyz"});
104+
arrayValidator.validate(
105+
globalTarget, "global_target", {nConsIdx, 3}, {"n_cons", "xyz"});
106+
arrayValidator.validate(parent, "parent", {nConsIdx}, {"n_cons"});
107+
validateJointIndex(parent, "parent", self.getSkeleton());
108+
arrayValidator.validate(weight, "weight", {nConsIdx}, {"n_cons"});
109+
110+
if (name.has_value() && name->size() != parent.shape(0)) {
111+
throw std::runtime_error(
112+
fmt::format(
113+
"Invalid names; expected {} names but got {}",
114+
parent.shape(0),
115+
name->size()));
116+
}
117+
118+
py::gil_scoped_release release;
119+
120+
auto localPointAcc = localPoint.unchecked<2>();
121+
auto localDirAcc = localDir.unchecked<2>();
122+
auto globalTargetAcc = globalTarget.unchecked<2>();
123+
auto parentAcc = parent.unchecked<1>();
124+
auto weightAcc =
125+
weight.has_value() ? std::make_optional(weight->unchecked<1>()) : std::nullopt;
126+
127+
for (py::ssize_t i = 0; i < localPoint.shape(0); ++i) {
128+
validateJointIndex(parentAcc(i), "parent", self.getSkeleton());
129+
self.addConstraint(
130+
mm::AimDataT<float>(
131+
Eigen::Vector3f(
132+
localPointAcc(i, 0), localPointAcc(i, 1), localPointAcc(i, 2)),
133+
Eigen::Vector3f(localDirAcc(i, 0), localDirAcc(i, 1), localDirAcc(i, 2)),
134+
Eigen::Vector3f(
135+
globalTargetAcc(i, 0), globalTargetAcc(i, 1), globalTargetAcc(i, 2)),
136+
parentAcc(i),
137+
weightAcc.has_value() ? (*weightAcc)(i) : 1.0f,
138+
name.has_value() ? name->at(i) : std::string{}));
139+
}
140+
},
141+
R"(Adds multiple aim constraints to the error function.
142+
143+
:param local_point: A numpy array of shape (n, 3) for the origins of the local rays.
144+
:param local_dir: A numpy array of shape (n, 3) for the directions of the local rays.
145+
:param global_target: A numpy array of shape (n, 3) for the global aim targets.
146+
:param parent_index: A numpy array of size n for the indices of the parent joints.
147+
:param weight: A numpy array of size n for the weights of the constraints.
148+
:param name: An optional list of names for the constraints (for debugging).)",
149+
py::arg("local_point"),
150+
py::arg("local_dir"),
151+
py::arg("global_target"),
152+
py::arg("parent_index"),
153+
py::arg("weight"),
154+
py::arg("name") = std::optional<std::vector<std::string>>{})
155+
.def(
156+
"clear_constraints",
157+
&AimErrorFunctionT::clearConstraints,
158+
"Clears all aim constraints from the error function.");
159+
}
160+
161+
template <typename FixedAxisErrorFunctionT>
162+
void defFixedAxisError(py::module_& m, const char* name, const char* description) {
163+
py::class_<
164+
FixedAxisErrorFunctionT,
165+
mm::SkeletonErrorFunctionT<float>,
166+
std::shared_ptr<FixedAxisErrorFunctionT>>(m, name, description)
167+
.def(
168+
"__repr__",
169+
[=](const FixedAxisErrorFunctionT& self) -> std::string {
170+
return fmt::format(
171+
"{}(weight={}, num_constraints={})", name, self.getWeight(), self.numConstraints());
172+
})
173+
.def(
174+
py::init<>(
175+
[](const mm::Character& character, float lossAlpha, float lossC, float weight)
176+
-> std::shared_ptr<FixedAxisErrorFunctionT> {
177+
validateWeight(weight, "weight");
178+
auto result =
179+
std::make_shared<FixedAxisErrorFunctionT>(character, lossAlpha, lossC);
180+
result->setWeight(weight);
181+
return result;
182+
}),
183+
R"(Initialize a FixedAxisDiffErrorFunction.
184+
185+
:param character: The character to use.
186+
:param alpha: P-norm to use; 2 is Euclidean distance, 1 is L1 norm, 0 is Cauchy norm.
187+
:param c: c parameter in the generalized loss function, this is roughly equivalent to normalizing by the standard deviation in a Gaussian distribution.
188+
:param weight: The weight applied to the error function.)",
189+
py::keep_alive<1, 2>(),
190+
py::arg("character"),
191+
py::kw_only(),
192+
py::arg("alpha") = 2.0f,
193+
py::arg("c") = 1.0f,
194+
py::arg("weight") = 1.0f)
195+
.def(
196+
"add_constraint",
197+
[](FixedAxisErrorFunctionT& self,
198+
const Eigen::Vector3f& localAxis,
199+
const Eigen::Vector3f& globalAxis,
200+
int parent,
201+
float weight,
202+
const std::string& name) {
203+
validateJointIndex(parent, "parent", self.getSkeleton());
204+
validateWeight(weight, "weight");
205+
self.addConstraint(
206+
mm::FixedAxisDataT<float>(localAxis, globalAxis, parent, weight, name));
207+
},
208+
R"(Adds a fixed axis constraint to the error function.
209+
210+
:param local_axis: The axis in the parent's coordinate frame.
211+
:param global_axis: The target axis in the global frame.
212+
:param parent: The index of the parent joint.
213+
:param weight: The weight of the constraint.
214+
:param name: The name of the constraint (for debugging).)",
215+
py::arg("local_axis"),
216+
py::arg("global_axis"),
217+
py::arg("parent"),
218+
py::arg("weight") = 1.0f,
219+
py::arg("name") = std::string{})
220+
.def_property_readonly(
221+
"constraints",
222+
[](const FixedAxisErrorFunctionT& self) { return self.getConstraints(); },
223+
"Returns the list of fixed axis constraints.")
224+
.def(
225+
"add_constraints",
226+
[](FixedAxisErrorFunctionT& self,
227+
const py::array_t<float>& localAxis,
228+
const py::array_t<float>& globalAxis,
229+
const py::array_t<int>& parent,
230+
const std::optional<py::array_t<float>>& weight,
231+
const std::optional<std::vector<std::string>>& name) {
232+
ArrayShapeValidator validator;
233+
const int nConsIdx = -1;
234+
validator.validate(localAxis, "local_axis", {nConsIdx, 3}, {"n_cons", "xyz"});
235+
validator.validate(globalAxis, "global_axis", {nConsIdx, 3}, {"n_cons", "xyz"});
236+
validator.validate(parent, "parent", {nConsIdx}, {"n_cons"});
237+
validateJointIndex(parent, "parent", self.getSkeleton());
238+
validator.validate(weight, "weight", {nConsIdx}, {"n_cons"});
239+
240+
if (name.has_value() && name->size() != parent.shape(0)) {
241+
throw std::runtime_error(
242+
fmt::format(
243+
"Invalid names; expected {} names but got {}",
244+
parent.shape(0),
245+
name->size()));
246+
}
247+
248+
auto localAxisAcc = localAxis.unchecked<2>();
249+
auto globalAxisAcc = globalAxis.unchecked<2>();
250+
auto parentAcc = parent.unchecked<1>();
251+
auto weightAcc =
252+
weight.has_value() ? std::make_optional(weight->unchecked<1>()) : std::nullopt;
253+
254+
for (py::ssize_t i = 0; i < localAxis.shape(0); ++i) {
255+
self.addConstraint(
256+
mm::FixedAxisDataT<float>(
257+
Eigen::Vector3f(localAxisAcc(i, 0), localAxisAcc(i, 1), localAxisAcc(i, 2)),
258+
Eigen::Vector3f(
259+
globalAxisAcc(i, 0), globalAxisAcc(i, 1), globalAxisAcc(i, 2)),
260+
parentAcc(i),
261+
weightAcc.has_value() ? (*weightAcc)(i) : 1.0f,
262+
name.has_value() ? name->at(i) : std::string{}));
263+
}
264+
},
265+
R"(Adds multiple fixed axis constraints to the error function.
266+
267+
:param local_axis: A numpy array of shape (n, 3) for the axes in the parent's coordinate frame.
268+
:param global_axis: A numpy array of shape (n, 3) for the target axes in the global frame.
269+
:param parent_index: A numpy array of size n for the indices of the parent joints.
270+
:param weight: A numpy array of size n for the weights of the constraints.
271+
:param name: An optional list of names for the constraints (for debugging).)",
272+
py::arg("local_axis"),
273+
py::arg("global_axis"),
274+
py::arg("parent_index"),
275+
py::arg("weight"),
276+
py::arg("name") = std::optional<std::vector<std::string>>{})
277+
.def(
278+
"clear_constraints",
279+
&FixedAxisErrorFunctionT::clearConstraints,
280+
"Clears all fixed axis constraints from the error function.");
281+
}
282+
283+
} // namespace
284+
285+
void addAimAxisErrorFunctions(py::module_& m) {
286+
// Aim error functions
287+
py::class_<mm::AimDataT<float>>(m, "AimData")
288+
.def(
289+
"__repr__",
290+
[](const mm::AimDataT<float>& self) {
291+
return fmt::format(
292+
"AimData(parent={}, weight={}, local_point=[{:.3f}, {:.3f}, {:.3f}], local_dir=[{:.3f}, {:.3f}, {:.3f}], global_target=[{:.3f}, {:.3f}, {:.3f}])",
293+
self.parent,
294+
self.weight,
295+
self.localPoint.x(),
296+
self.localPoint.y(),
297+
self.localPoint.z(),
298+
self.localDir.x(),
299+
self.localDir.y(),
300+
self.localDir.z(),
301+
self.globalTarget.x(),
302+
self.globalTarget.y(),
303+
self.globalTarget.z());
304+
})
305+
.def_readonly("parent", &mm::AimDataT<float>::parent, "The parent joint index")
306+
.def_readonly("weight", &mm::AimDataT<float>::weight, "The weight of the constraint")
307+
.def_readonly("local_point", &mm::AimDataT<float>::localPoint, "The origin of the local ray")
308+
.def_readonly("local_dir", &mm::AimDataT<float>::localDir, "The direction of the local ray")
309+
.def_readonly("global_target", &mm::AimDataT<float>::globalTarget, "The global aim target");
310+
311+
defAimErrorFunction<mm::AimDistErrorFunctionT<float>>(
312+
m,
313+
"AimDistErrorFunction",
314+
R"(The AimDistErrorFunction minimizes the distance between a ray (origin, direction) defined in
315+
joint-local space and the world-space target point. The residual is defined as the distance
316+
between the tartet position and its projection onto the ray. Note that the ray is only defined
317+
for positive t values, meaning that if the target point is _behind_ the ray origin, its
318+
projection will be at the ray origin where t=0.)");
319+
320+
defAimErrorFunction<mm::AimDirErrorFunctionT<float>>(m, "AimDirErrorFunction", R"(
321+
The AimDirErrorFunction minimizes the element-wise difference between a ray (origin, direction)
322+
defined in joint-local space and the normalized vector connecting the ray origin to the
323+
world-space target point. If the vector has near-zero length, the residual is set to zero to
324+
avoid divide-by-zero. )");
325+
326+
// Fixed axis error functions
327+
py::class_<mm::FixedAxisDataT<float>>(m, "FixedAxisData")
328+
.def(
329+
"__repr__",
330+
[](const mm::FixedAxisDataT<float>& self) {
331+
return fmt::format(
332+
"FixedAxisData(parent={}, weight={}, local_axis=[{:.3f}, {:.3f}, {:.3f}], global_axis=[{:.3f}, {:.3f}, {:.3f}])",
333+
self.parent,
334+
self.weight,
335+
self.localAxis.x(),
336+
self.localAxis.y(),
337+
self.localAxis.z(),
338+
self.globalAxis.x(),
339+
self.globalAxis.y(),
340+
self.globalAxis.z());
341+
})
342+
.def_readonly("parent", &mm::FixedAxisDataT<float>::parent, "The parent joint index")
343+
.def_readonly("weight", &mm::FixedAxisDataT<float>::weight, "The weight of the constraint")
344+
.def_readonly(
345+
"local_axis", &mm::FixedAxisDataT<float>::localAxis, "The local axis in parent space")
346+
.def_readonly("global_axis", &mm::FixedAxisDataT<float>::globalAxis, "The global axis");
347+
348+
defFixedAxisError<mm::FixedAxisDiffErrorFunctionT<float>>(
349+
m,
350+
"FixedAxisDiffErrorFunction",
351+
R"(Error function that minimizes the difference between a local axis (in
352+
joint-local space) and a global axis using element-wise differences.)");
353+
defFixedAxisError<mm::FixedAxisCosErrorFunctionT<float>>(
354+
m,
355+
"FixedAxisCosErrorFunction",
356+
R"(Error function that minimizes the difference between a local axis (in
357+
joint-local space) and a global axis using the cosine of the angle between
358+
the vectors (which is the same as the dot product of the vectors).)");
359+
defFixedAxisError<mm::FixedAxisAngleErrorFunctionT<float>>(
360+
m,
361+
"FixedAxisAngleErrorFunction",
362+
R"(Error function that minimizes the difference between a local axis (in
363+
joint-local space) and a global axis using the angle between the vectors.)");
364+
}
365+
366+
} // namespace pymomentum

0 commit comments

Comments
 (0)