|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * |
| 4 | + * This source code is licensed under the MIT license found in the |
| 5 | + * LICENSE file in the root directory of this source tree. |
| 6 | + */ |
| 7 | + |
| 8 | +#include <pymomentum/solver2/solver2_error_functions.h> |
| 9 | + |
| 10 | +#include <momentum/character/character.h> |
| 11 | +#include <momentum/character_solver/aim_error_function.h> |
| 12 | +#include <momentum/character_solver/fixed_axis_error_function.h> |
| 13 | +#include <pymomentum/solver2/solver2_utility.h> |
| 14 | + |
| 15 | +#include <fmt/format.h> |
| 16 | +#include <pybind11/eigen.h> |
| 17 | +#include <pybind11/pybind11.h> |
| 18 | +#include <pybind11/stl.h> |
| 19 | +#include <Eigen/Core> |
| 20 | + |
| 21 | +namespace py = pybind11; |
| 22 | +namespace mm = momentum; |
| 23 | + |
| 24 | +namespace pymomentum { |
| 25 | + |
| 26 | +namespace { |
| 27 | + |
| 28 | +template <typename AimErrorFunctionT> |
| 29 | +void defAimErrorFunction(py::module_& m, const char* name, const char* description) { |
| 30 | + py::class_<AimErrorFunctionT, mm::SkeletonErrorFunction, std::shared_ptr<AimErrorFunctionT>>( |
| 31 | + m, name, description) |
| 32 | + .def( |
| 33 | + "__repr__", |
| 34 | + [=](const AimErrorFunctionT& self) -> std::string { |
| 35 | + return fmt::format( |
| 36 | + "{}(weight={}, num_constraints={})", name, self.getWeight(), self.numConstraints()); |
| 37 | + }) |
| 38 | + .def( |
| 39 | + py::init<>( |
| 40 | + [](const mm::Character& character, float lossAlpha, float lossC, float weight) |
| 41 | + -> std::shared_ptr<AimErrorFunctionT> { |
| 42 | + validateWeight(weight, "weight"); |
| 43 | + auto result = std::make_shared<AimErrorFunctionT>(character, lossAlpha, lossC); |
| 44 | + result->setWeight(weight); |
| 45 | + return result; |
| 46 | + }), |
| 47 | + R"(Initialize error function. |
| 48 | +
|
| 49 | + :param character: The character to use. |
| 50 | + :param alpha: P-norm to use; 2 is Euclidean distance, 1 is L1 norm, 0 is Cauchy norm. |
| 51 | + :param c: c parameter in the generalized loss function, this is roughly equivalent to normalizing by the standard deviation in a Gaussian distribution. |
| 52 | + :param weight: The weight applied to the error function.)", |
| 53 | + py::keep_alive<1, 2>(), |
| 54 | + py::arg("character"), |
| 55 | + py::kw_only(), |
| 56 | + py::arg("alpha") = 2.0f, |
| 57 | + py::arg("c") = 1.0f, |
| 58 | + py::arg("weight") = 1.0f) |
| 59 | + .def_property_readonly( |
| 60 | + "constraints", |
| 61 | + [](const AimErrorFunctionT& self) { return self.getConstraints(); }, |
| 62 | + "Returns the list of aim constraints.") |
| 63 | + .def( |
| 64 | + "add_constraint", |
| 65 | + [](AimErrorFunctionT& self, |
| 66 | + const Eigen::Vector3f& localPoint, |
| 67 | + const Eigen::Vector3f& localDir, |
| 68 | + const Eigen::Vector3f& globalTarget, |
| 69 | + int parent, |
| 70 | + float weight, |
| 71 | + const std::string& name) { |
| 72 | + validateJointIndex(parent, "parent", self.getSkeleton()); |
| 73 | + validateWeight(weight, "weight"); |
| 74 | + self.addConstraint( |
| 75 | + mm::AimDataT<float>(localPoint, localDir, globalTarget, parent, weight, name)); |
| 76 | + }, |
| 77 | + R"(Adds an aim constraint to the error function. |
| 78 | +
|
| 79 | + :param local_point: The origin of the local ray. |
| 80 | + :param local_dir: The direction of the local ray. |
| 81 | + :param global_target: The global aim target. |
| 82 | + :param parent_index: The index of the parent joint. |
| 83 | + :param weight: The weight of the constraint. |
| 84 | + :param name: The name of the constraint (for debugging).)", |
| 85 | + py::arg("local_point"), |
| 86 | + py::arg("local_dir"), |
| 87 | + py::arg("global_target"), |
| 88 | + py::arg("parent"), |
| 89 | + py::arg("weight") = 1.0f, |
| 90 | + py::arg("name") = std::string{}) |
| 91 | + .def( |
| 92 | + "add_constraints", |
| 93 | + [](AimErrorFunctionT& self, |
| 94 | + const py::array_t<float>& localPoint, |
| 95 | + const py::array_t<float>& localDir, |
| 96 | + const py::array_t<float>& globalTarget, |
| 97 | + const py::array_t<int>& parent, |
| 98 | + const std::optional<py::array_t<float>>& weight, |
| 99 | + const std::optional<std::vector<std::string>>& name) { |
| 100 | + ArrayShapeValidator arrayValidator; |
| 101 | + const int nConsIdx = -1; |
| 102 | + arrayValidator.validate(localPoint, "local_point", {nConsIdx, 3}, {"n_cons", "xyz"}); |
| 103 | + arrayValidator.validate(localDir, "local_dir", {nConsIdx, 3}, {"n_cons", "xyz"}); |
| 104 | + arrayValidator.validate( |
| 105 | + globalTarget, "global_target", {nConsIdx, 3}, {"n_cons", "xyz"}); |
| 106 | + arrayValidator.validate(parent, "parent", {nConsIdx}, {"n_cons"}); |
| 107 | + validateJointIndex(parent, "parent", self.getSkeleton()); |
| 108 | + arrayValidator.validate(weight, "weight", {nConsIdx}, {"n_cons"}); |
| 109 | + |
| 110 | + if (name.has_value() && name->size() != parent.shape(0)) { |
| 111 | + throw std::runtime_error( |
| 112 | + fmt::format( |
| 113 | + "Invalid names; expected {} names but got {}", |
| 114 | + parent.shape(0), |
| 115 | + name->size())); |
| 116 | + } |
| 117 | + |
| 118 | + py::gil_scoped_release release; |
| 119 | + |
| 120 | + auto localPointAcc = localPoint.unchecked<2>(); |
| 121 | + auto localDirAcc = localDir.unchecked<2>(); |
| 122 | + auto globalTargetAcc = globalTarget.unchecked<2>(); |
| 123 | + auto parentAcc = parent.unchecked<1>(); |
| 124 | + auto weightAcc = |
| 125 | + weight.has_value() ? std::make_optional(weight->unchecked<1>()) : std::nullopt; |
| 126 | + |
| 127 | + for (py::ssize_t i = 0; i < localPoint.shape(0); ++i) { |
| 128 | + validateJointIndex(parentAcc(i), "parent", self.getSkeleton()); |
| 129 | + self.addConstraint( |
| 130 | + mm::AimDataT<float>( |
| 131 | + Eigen::Vector3f( |
| 132 | + localPointAcc(i, 0), localPointAcc(i, 1), localPointAcc(i, 2)), |
| 133 | + Eigen::Vector3f(localDirAcc(i, 0), localDirAcc(i, 1), localDirAcc(i, 2)), |
| 134 | + Eigen::Vector3f( |
| 135 | + globalTargetAcc(i, 0), globalTargetAcc(i, 1), globalTargetAcc(i, 2)), |
| 136 | + parentAcc(i), |
| 137 | + weightAcc.has_value() ? (*weightAcc)(i) : 1.0f, |
| 138 | + name.has_value() ? name->at(i) : std::string{})); |
| 139 | + } |
| 140 | + }, |
| 141 | + R"(Adds multiple aim constraints to the error function. |
| 142 | +
|
| 143 | + :param local_point: A numpy array of shape (n, 3) for the origins of the local rays. |
| 144 | + :param local_dir: A numpy array of shape (n, 3) for the directions of the local rays. |
| 145 | + :param global_target: A numpy array of shape (n, 3) for the global aim targets. |
| 146 | + :param parent_index: A numpy array of size n for the indices of the parent joints. |
| 147 | + :param weight: A numpy array of size n for the weights of the constraints. |
| 148 | + :param name: An optional list of names for the constraints (for debugging).)", |
| 149 | + py::arg("local_point"), |
| 150 | + py::arg("local_dir"), |
| 151 | + py::arg("global_target"), |
| 152 | + py::arg("parent_index"), |
| 153 | + py::arg("weight"), |
| 154 | + py::arg("name") = std::optional<std::vector<std::string>>{}) |
| 155 | + .def( |
| 156 | + "clear_constraints", |
| 157 | + &AimErrorFunctionT::clearConstraints, |
| 158 | + "Clears all aim constraints from the error function."); |
| 159 | +} |
| 160 | + |
| 161 | +template <typename FixedAxisErrorFunctionT> |
| 162 | +void defFixedAxisError(py::module_& m, const char* name, const char* description) { |
| 163 | + py::class_< |
| 164 | + FixedAxisErrorFunctionT, |
| 165 | + mm::SkeletonErrorFunctionT<float>, |
| 166 | + std::shared_ptr<FixedAxisErrorFunctionT>>(m, name, description) |
| 167 | + .def( |
| 168 | + "__repr__", |
| 169 | + [=](const FixedAxisErrorFunctionT& self) -> std::string { |
| 170 | + return fmt::format( |
| 171 | + "{}(weight={}, num_constraints={})", name, self.getWeight(), self.numConstraints()); |
| 172 | + }) |
| 173 | + .def( |
| 174 | + py::init<>( |
| 175 | + [](const mm::Character& character, float lossAlpha, float lossC, float weight) |
| 176 | + -> std::shared_ptr<FixedAxisErrorFunctionT> { |
| 177 | + validateWeight(weight, "weight"); |
| 178 | + auto result = |
| 179 | + std::make_shared<FixedAxisErrorFunctionT>(character, lossAlpha, lossC); |
| 180 | + result->setWeight(weight); |
| 181 | + return result; |
| 182 | + }), |
| 183 | + R"(Initialize a FixedAxisDiffErrorFunction. |
| 184 | +
|
| 185 | + :param character: The character to use. |
| 186 | + :param alpha: P-norm to use; 2 is Euclidean distance, 1 is L1 norm, 0 is Cauchy norm. |
| 187 | + :param c: c parameter in the generalized loss function, this is roughly equivalent to normalizing by the standard deviation in a Gaussian distribution. |
| 188 | + :param weight: The weight applied to the error function.)", |
| 189 | + py::keep_alive<1, 2>(), |
| 190 | + py::arg("character"), |
| 191 | + py::kw_only(), |
| 192 | + py::arg("alpha") = 2.0f, |
| 193 | + py::arg("c") = 1.0f, |
| 194 | + py::arg("weight") = 1.0f) |
| 195 | + .def( |
| 196 | + "add_constraint", |
| 197 | + [](FixedAxisErrorFunctionT& self, |
| 198 | + const Eigen::Vector3f& localAxis, |
| 199 | + const Eigen::Vector3f& globalAxis, |
| 200 | + int parent, |
| 201 | + float weight, |
| 202 | + const std::string& name) { |
| 203 | + validateJointIndex(parent, "parent", self.getSkeleton()); |
| 204 | + validateWeight(weight, "weight"); |
| 205 | + self.addConstraint( |
| 206 | + mm::FixedAxisDataT<float>(localAxis, globalAxis, parent, weight, name)); |
| 207 | + }, |
| 208 | + R"(Adds a fixed axis constraint to the error function. |
| 209 | +
|
| 210 | + :param local_axis: The axis in the parent's coordinate frame. |
| 211 | + :param global_axis: The target axis in the global frame. |
| 212 | + :param parent: The index of the parent joint. |
| 213 | + :param weight: The weight of the constraint. |
| 214 | + :param name: The name of the constraint (for debugging).)", |
| 215 | + py::arg("local_axis"), |
| 216 | + py::arg("global_axis"), |
| 217 | + py::arg("parent"), |
| 218 | + py::arg("weight") = 1.0f, |
| 219 | + py::arg("name") = std::string{}) |
| 220 | + .def_property_readonly( |
| 221 | + "constraints", |
| 222 | + [](const FixedAxisErrorFunctionT& self) { return self.getConstraints(); }, |
| 223 | + "Returns the list of fixed axis constraints.") |
| 224 | + .def( |
| 225 | + "add_constraints", |
| 226 | + [](FixedAxisErrorFunctionT& self, |
| 227 | + const py::array_t<float>& localAxis, |
| 228 | + const py::array_t<float>& globalAxis, |
| 229 | + const py::array_t<int>& parent, |
| 230 | + const std::optional<py::array_t<float>>& weight, |
| 231 | + const std::optional<std::vector<std::string>>& name) { |
| 232 | + ArrayShapeValidator validator; |
| 233 | + const int nConsIdx = -1; |
| 234 | + validator.validate(localAxis, "local_axis", {nConsIdx, 3}, {"n_cons", "xyz"}); |
| 235 | + validator.validate(globalAxis, "global_axis", {nConsIdx, 3}, {"n_cons", "xyz"}); |
| 236 | + validator.validate(parent, "parent", {nConsIdx}, {"n_cons"}); |
| 237 | + validateJointIndex(parent, "parent", self.getSkeleton()); |
| 238 | + validator.validate(weight, "weight", {nConsIdx}, {"n_cons"}); |
| 239 | + |
| 240 | + if (name.has_value() && name->size() != parent.shape(0)) { |
| 241 | + throw std::runtime_error( |
| 242 | + fmt::format( |
| 243 | + "Invalid names; expected {} names but got {}", |
| 244 | + parent.shape(0), |
| 245 | + name->size())); |
| 246 | + } |
| 247 | + |
| 248 | + auto localAxisAcc = localAxis.unchecked<2>(); |
| 249 | + auto globalAxisAcc = globalAxis.unchecked<2>(); |
| 250 | + auto parentAcc = parent.unchecked<1>(); |
| 251 | + auto weightAcc = |
| 252 | + weight.has_value() ? std::make_optional(weight->unchecked<1>()) : std::nullopt; |
| 253 | + |
| 254 | + for (py::ssize_t i = 0; i < localAxis.shape(0); ++i) { |
| 255 | + self.addConstraint( |
| 256 | + mm::FixedAxisDataT<float>( |
| 257 | + Eigen::Vector3f(localAxisAcc(i, 0), localAxisAcc(i, 1), localAxisAcc(i, 2)), |
| 258 | + Eigen::Vector3f( |
| 259 | + globalAxisAcc(i, 0), globalAxisAcc(i, 1), globalAxisAcc(i, 2)), |
| 260 | + parentAcc(i), |
| 261 | + weightAcc.has_value() ? (*weightAcc)(i) : 1.0f, |
| 262 | + name.has_value() ? name->at(i) : std::string{})); |
| 263 | + } |
| 264 | + }, |
| 265 | + R"(Adds multiple fixed axis constraints to the error function. |
| 266 | +
|
| 267 | + :param local_axis: A numpy array of shape (n, 3) for the axes in the parent's coordinate frame. |
| 268 | + :param global_axis: A numpy array of shape (n, 3) for the target axes in the global frame. |
| 269 | + :param parent_index: A numpy array of size n for the indices of the parent joints. |
| 270 | + :param weight: A numpy array of size n for the weights of the constraints. |
| 271 | + :param name: An optional list of names for the constraints (for debugging).)", |
| 272 | + py::arg("local_axis"), |
| 273 | + py::arg("global_axis"), |
| 274 | + py::arg("parent_index"), |
| 275 | + py::arg("weight"), |
| 276 | + py::arg("name") = std::optional<std::vector<std::string>>{}) |
| 277 | + .def( |
| 278 | + "clear_constraints", |
| 279 | + &FixedAxisErrorFunctionT::clearConstraints, |
| 280 | + "Clears all fixed axis constraints from the error function."); |
| 281 | +} |
| 282 | + |
| 283 | +} // namespace |
| 284 | + |
| 285 | +void addAimAxisErrorFunctions(py::module_& m) { |
| 286 | + // Aim error functions |
| 287 | + py::class_<mm::AimDataT<float>>(m, "AimData") |
| 288 | + .def( |
| 289 | + "__repr__", |
| 290 | + [](const mm::AimDataT<float>& self) { |
| 291 | + return fmt::format( |
| 292 | + "AimData(parent={}, weight={}, local_point=[{:.3f}, {:.3f}, {:.3f}], local_dir=[{:.3f}, {:.3f}, {:.3f}], global_target=[{:.3f}, {:.3f}, {:.3f}])", |
| 293 | + self.parent, |
| 294 | + self.weight, |
| 295 | + self.localPoint.x(), |
| 296 | + self.localPoint.y(), |
| 297 | + self.localPoint.z(), |
| 298 | + self.localDir.x(), |
| 299 | + self.localDir.y(), |
| 300 | + self.localDir.z(), |
| 301 | + self.globalTarget.x(), |
| 302 | + self.globalTarget.y(), |
| 303 | + self.globalTarget.z()); |
| 304 | + }) |
| 305 | + .def_readonly("parent", &mm::AimDataT<float>::parent, "The parent joint index") |
| 306 | + .def_readonly("weight", &mm::AimDataT<float>::weight, "The weight of the constraint") |
| 307 | + .def_readonly("local_point", &mm::AimDataT<float>::localPoint, "The origin of the local ray") |
| 308 | + .def_readonly("local_dir", &mm::AimDataT<float>::localDir, "The direction of the local ray") |
| 309 | + .def_readonly("global_target", &mm::AimDataT<float>::globalTarget, "The global aim target"); |
| 310 | + |
| 311 | + defAimErrorFunction<mm::AimDistErrorFunctionT<float>>( |
| 312 | + m, |
| 313 | + "AimDistErrorFunction", |
| 314 | + R"(The AimDistErrorFunction minimizes the distance between a ray (origin, direction) defined in |
| 315 | +joint-local space and the world-space target point. The residual is defined as the distance |
| 316 | +between the tartet position and its projection onto the ray. Note that the ray is only defined |
| 317 | +for positive t values, meaning that if the target point is _behind_ the ray origin, its |
| 318 | +projection will be at the ray origin where t=0.)"); |
| 319 | + |
| 320 | + defAimErrorFunction<mm::AimDirErrorFunctionT<float>>(m, "AimDirErrorFunction", R"( |
| 321 | +The AimDirErrorFunction minimizes the element-wise difference between a ray (origin, direction) |
| 322 | +defined in joint-local space and the normalized vector connecting the ray origin to the |
| 323 | +world-space target point. If the vector has near-zero length, the residual is set to zero to |
| 324 | +avoid divide-by-zero. )"); |
| 325 | + |
| 326 | + // Fixed axis error functions |
| 327 | + py::class_<mm::FixedAxisDataT<float>>(m, "FixedAxisData") |
| 328 | + .def( |
| 329 | + "__repr__", |
| 330 | + [](const mm::FixedAxisDataT<float>& self) { |
| 331 | + return fmt::format( |
| 332 | + "FixedAxisData(parent={}, weight={}, local_axis=[{:.3f}, {:.3f}, {:.3f}], global_axis=[{:.3f}, {:.3f}, {:.3f}])", |
| 333 | + self.parent, |
| 334 | + self.weight, |
| 335 | + self.localAxis.x(), |
| 336 | + self.localAxis.y(), |
| 337 | + self.localAxis.z(), |
| 338 | + self.globalAxis.x(), |
| 339 | + self.globalAxis.y(), |
| 340 | + self.globalAxis.z()); |
| 341 | + }) |
| 342 | + .def_readonly("parent", &mm::FixedAxisDataT<float>::parent, "The parent joint index") |
| 343 | + .def_readonly("weight", &mm::FixedAxisDataT<float>::weight, "The weight of the constraint") |
| 344 | + .def_readonly( |
| 345 | + "local_axis", &mm::FixedAxisDataT<float>::localAxis, "The local axis in parent space") |
| 346 | + .def_readonly("global_axis", &mm::FixedAxisDataT<float>::globalAxis, "The global axis"); |
| 347 | + |
| 348 | + defFixedAxisError<mm::FixedAxisDiffErrorFunctionT<float>>( |
| 349 | + m, |
| 350 | + "FixedAxisDiffErrorFunction", |
| 351 | + R"(Error function that minimizes the difference between a local axis (in |
| 352 | + joint-local space) and a global axis using element-wise differences.)"); |
| 353 | + defFixedAxisError<mm::FixedAxisCosErrorFunctionT<float>>( |
| 354 | + m, |
| 355 | + "FixedAxisCosErrorFunction", |
| 356 | + R"(Error function that minimizes the difference between a local axis (in |
| 357 | + joint-local space) and a global axis using the cosine of the angle between |
| 358 | + the vectors (which is the same as the dot product of the vectors).)"); |
| 359 | + defFixedAxisError<mm::FixedAxisAngleErrorFunctionT<float>>( |
| 360 | + m, |
| 361 | + "FixedAxisAngleErrorFunction", |
| 362 | + R"(Error function that minimizes the difference between a local axis (in |
| 363 | + joint-local space) and a global axis using the angle between the vectors.)"); |
| 364 | +} |
| 365 | + |
| 366 | +} // namespace pymomentum |
0 commit comments