Lagrange
Loading...
Searching...
No Matches
eigen_utils.h
/*
 * Copyright 2025 Adobe. All rights reserved.
 * This file is licensed to you under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License. You may obtain a copy
 * of the License at http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under
 * the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
 * OF ANY KIND, either express or implied. See the License for the specific language
 * governing permissions and limitations under the License.
 */
12#pragma once
#include <lagrange/Logger.h>
#include <lagrange/python/binding.h>
#include <lagrange/utils/fmt/format.h>

#include <Eigen/Core>

#include <stdexcept>
#include <variant>
18
19namespace lagrange::python {
20namespace nb = nanobind;
21
22template <typename Scalar, int Dim>
23using Point = Eigen::Matrix<Scalar, 1, Dim>;
24
25template <typename Scalar, int Dim>
26using NBPoint = nb::ndarray<Scalar, nb::shape<Dim>, nb::c_contig, nb::device::cpu>;
27
28template <typename Scalar, int Dim>
29using GenericPoint = std::variant<nb::list, NBPoint<Scalar, Dim>>;
30
31
38template <typename Scalar, int Dim>
39Point<Scalar, Dim> to_eigen_point(const GenericPoint<Scalar, Dim>& p)
40{
41 Point<Scalar, Dim> q;
42 if (std::holds_alternative<nb::list>(p)) {
43 auto lst = std::get<nb::list>(p);
44 if (lst.size() != Dim) {
45 throw std::runtime_error(
46 lagrange::format("Point list must have exactly {} elements.", Dim));
47 }
48 for (int i = 0; i < Dim; ++i) {
49 q(i) = nb::cast<Scalar>(lst[i]);
50 }
51 } else {
52 auto arr = std::get<NBPoint<Scalar, Dim>>(p);
53 for (int i = 0; i < Dim; ++i) {
54 q(i) = arr(i);
55 }
56 }
57 return q;
58}
59
60} // namespace lagrange::python