Lagrange
Loading...
Searching...
No Matches
eigen_utils.h
1/*
2 * Copyright 2025 Adobe. All rights reserved.
3 * This file is licensed to you under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License. You may obtain a copy
5 * of the License at http://www.apache.org/licenses/LICENSE-2.0
6 *
7 * Unless required by applicable law or agreed to in writing, software distributed under
8 * the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
9 * OF ANY KIND, either express or implied. See the License for the specific language
10 * governing permissions and limitations under the License.
11 */
12#pragma once
13#include <lagrange/Logger.h>
14#include <lagrange/python/binding.h>
15
16#include <Eigen/Core>
17
18namespace lagrange::python {
19namespace nb = nanobind;
20
21template <typename Scalar, int Dim>
22using Point = Eigen::Matrix<Scalar, 1, Dim>;
23
24template <typename Scalar, int Dim>
25using NBPoint = nb::ndarray<Scalar, nb::shape<Dim>, nb::c_contig, nb::device::cpu>;
26
27template <typename Scalar, int Dim>
28using GenericPoint = std::variant<nb::list, NBPoint<Scalar, Dim>>;
29
30
37template <typename Scalar, int Dim>
38Point<Scalar, Dim> to_eigen_point(const GenericPoint<Scalar, Dim>& p)
39{
40 Point<Scalar, Dim> q;
41 if (std::holds_alternative<nb::list>(p)) {
42 auto lst = std::get<nb::list>(p);
43 if (lst.size() != Dim) {
44 throw std::runtime_error(fmt::format("Point list must have exactly {} elements.", Dim));
45 }
46 for (int i = 0; i < Dim; ++i) {
47 q(i) = nb::cast<Scalar>(lst[i]);
48 }
49 } else {
50 auto arr = std::get<NBPoint<Scalar, Dim>>(p);
51 for (int i = 0; i < Dim; ++i) {
52 q(i) = arr(i);
53 }
54 }
55 return q;
56}
57
58} // namespace lagrange::python