Lagrange
Array.impl.h
/*
 * Copyright 2020 Adobe. All rights reserved.
 * This file is licensed to you under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License. You may obtain a copy
 * of the License at http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under
 * the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
 * OF ANY KIND, either express or implied. See the License for the specific language
 * governing permissions and limitations under the License.
 */
12
13// clang-format off
14#include <lagrange/utils/warnoff.h>
15#include <tbb/enumerable_thread_specific.h>
16#include <tbb/parallel_for.h>
17#include <lagrange/utils/warnon.h>
18// clang-format on
19
20namespace lagrange {
21namespace experimental {
22
23template <typename Derived>
24void ArrayBase::set(const Eigen::MatrixBase<Derived>& data)
25{
26 using EvalType = std::decay_t<decltype(data.eval())>;
27 if (auto ptr = down_cast<EigenArray<EvalType>*>()) {
28 ptr->set(data);
29 } else if (
30 auto ptr2 = down_cast<RawArray<
31 typename EvalType::Scalar,
32 EvalType::RowsAtCompileTime,
33 EvalType::ColsAtCompileTime,
34 EvalType::Options>*>()) {
35 ptr2->set(data);
36 } else if (auto ptr3 = down_cast<EigenArrayRef<EvalType>*>()) {
37 ptr3->set(data);
38 } else if (is_compatible<EvalType>(true)) {
39 // The Derived type is not an exact match to the true type, fall back to
40 // brute force copying.
41 if (EvalType::IsRowMajor == is_row_major()) {
42 resize(data.rows(), data.cols());
43 view<EvalType>() = data;
44 } else {
45 // Type is compatible except storage order.
46 using TransposedType = std::decay_t<decltype(data.transpose().eval())>;
47 la_runtime_assert(is_compatible<TransposedType>());
48 resize(data.rows(), data.cols());
49 view<TransposedType>() = data;
50 }
51 } else {
52 throw std::runtime_error(string_format(
53 "Unsupported type passed to ArrayBase::set(). Expecting {}",
54 type_name()));
55 }
56}
57
58template <typename Derived>
59void ArrayBase::set(Eigen::MatrixBase<Derived>&& data)
60{
61 using EvalType = std::decay_t<decltype(data.eval())>;
62 if (auto ptr = down_cast<EigenArray<EvalType>*>()) {
63 ptr->set(std::move(data.derived()));
64 } else if (
65 auto ptr2 = down_cast<RawArray<
66 typename EvalType::Scalar,
67 EvalType::RowsAtCompileTime,
68 EvalType::ColsAtCompileTime,
69 EvalType::Options>*>()) {
70 // Fall back to copying because RawArray cannot change its data memory
71 // pointer.
72 ptr2->set(data);
73 } else if (auto ptr3 = down_cast<EigenArrayRef<EvalType>*>()) {
74 ptr3->set(std::move(data.derived()));
75 } else if (is_compatible<EvalType>(true)) {
76 // The Derived type is not an exact match to the true type, fall back to
77 // brute force copying.
78 if (EvalType::IsRowMajor == is_row_major()) {
79 resize(data.rows(), data.cols());
80 view<EvalType>() = data;
81 } else {
82 // Type is compatible except storage order.
83 using TransposedType = std::decay_t<decltype(data.transpose().eval())>;
84 la_runtime_assert(is_compatible<TransposedType>());
85 resize(data.rows(), data.cols());
86 view<TransposedType>() = std::move(data);
87 }
88 } else {
89 throw std::runtime_error(string_format(
90 "Unsupported type passed to ArrayBase::set(). Expecting {}",
91 type_name()));
92 }
93}
94
95template <typename _TargetType>
96auto& ArrayBase::get()
97{
98 using TargetType = std::decay_t<_TargetType>;
99 if (auto ptr = down_cast<EigenArray<TargetType>*>()) {
100 return ptr->get_ref();
101 } else if (auto ptr2 = down_cast<EigenArrayRef<TargetType>*>()) {
102 return ptr2->get_ref();
103 } else {
104 throw std::runtime_error(string_format(
105 "Unsupported type passed to ArrayBase::get(). Expecting {}",
106 type_name()));
107 }
108}
109
110template <typename _TargetType>
111const auto& ArrayBase::get() const
112{
113 using TargetType = std::decay_t<_TargetType>;
114 if (auto ptr = down_cast<const EigenArray<TargetType>*>()) {
115 return ptr->get_ref();
116 } else if (auto ptr3 = down_cast<const EigenArrayRef<TargetType>*>()) {
117 return ptr3->get_ref();
118 } else if (auto ptr4 = down_cast<const EigenArrayRef<const TargetType>*>()) {
119 return ptr4->get_ref();
120 } else {
121 throw std::runtime_error(string_format(
122 "Unsupported type passed to const ArrayBase::get(). Expecting {}",
123 type_name()));
124 }
125}
126
127template <typename TargetType>
128bool ArrayBase::is_compatible(bool ignore_storage_order) const
129{
130 using Scalar = typename TargetType::Scalar;
131 constexpr int RowsAtCompileTime = TargetType::RowsAtCompileTime;
132 constexpr int ColsAtCompileTime = TargetType::ColsAtCompileTime;
133 constexpr int row_major = TargetType::IsRowMajor;
134
135 if (RowsAtCompileTime > 0 && this->rows() != RowsAtCompileTime) return false;
136 if (ColsAtCompileTime > 0 && this->cols() != ColsAtCompileTime) return false;
137 if (ScalarToEnum_v<Scalar> != m_scalar_type) return false;
138 if (!ignore_storage_order && this->rows() != 1 && this->cols() != 1 &&
139 row_major != is_row_major()) {
140 lagrange::logger().error(
141 "Target storage order ({}) does not match array storage order ({}).",
142 row_major == 1 ? "RowMajor" : "ColMajor",
143 is_row_major() == 1 ? "RowMajor" : "ColMajor");
144 return false;
145 }
146
147 return true;
148}
149
150
151template <typename Derived>
152std::unique_ptr<ArrayBase> ArrayBase::row_slice_impl(
153 const Eigen::MatrixBase<Derived>& matrix,
154 Index num_rows,
155 const IndexFunction& mapping_fn)
156{
157 using Scalar = typename Derived::Scalar;
158 using OutEigenType = Eigen::Matrix<
159 Scalar,
160 Eigen::Dynamic,
161 Derived::ColsAtCompileTime,
162 Derived::IsRowMajor ? Eigen::RowMajor : Eigen::ColMajor>;
163
164 const auto num_cols = matrix.cols();
165 OutEigenType out_matrix(num_rows, num_cols);
166 tbb::parallel_for(
167 tbb::blocked_range<size_t>(0, num_rows),
168 [&](const tbb::blocked_range<size_t>& r) {
169 for (auto i = r.begin(); i != r.end(); i++) {
170 out_matrix.row(i) = matrix.row(mapping_fn(i));
171 }
172 });
173
174 return std::make_unique<EigenArray<OutEigenType>>(std::move(out_matrix));
175}
176
177template <typename Derived>
178std::enable_if_t<std::is_integral<typename Derived::Scalar>::value, std::unique_ptr<ArrayBase>>
179ArrayBase::row_slice_impl(
180 const Eigen::MatrixBase<Derived>& matrix,
181 Index num_rows,
182 const WeightedIndexFunction& mapping_fn)
183{
184 using Scalar = typename Derived::Scalar;
185 using OutEigenType = Eigen::Matrix<
186 Scalar,
187 Eigen::Dynamic,
188 Derived::ColsAtCompileTime,
189 Derived::IsRowMajor ? Eigen::RowMajor : Eigen::ColMajor>;
190 const auto num_cols = matrix.cols();
191 tbb::enumerable_thread_specific<std::vector<std::pair<Index, double>>> weights;
192
193 using DoubleEigenType = Eigen::Matrix<
194 double,
195 Eigen::Dynamic,
196 Derived::ColsAtCompileTime,
197 Derived::IsRowMajor ? Eigen::RowMajor : Eigen::ColMajor>;
198 DoubleEigenType out_matrix(num_rows, num_cols);
199 out_matrix.setZero();
200
201 tbb::parallel_for(
202 tbb::blocked_range<Index>(0, num_rows),
203 [&](const tbb::blocked_range<Index>& r) {
204 for (auto i = r.begin(); i != r.end(); i++) {
205 auto& entries = weights.local();
206 mapping_fn(i, entries);
207 for (const auto& entry : entries) {
208 out_matrix.row(i) +=
209 matrix.row(entry.first).template cast<double>() * entry.second;
210 }
211 }
212 });
213 return std::make_unique<EigenArray<OutEigenType>>(
214 std::move(out_matrix.array().round().template cast<Scalar>().matrix().eval()));
215}
216
217template <typename Derived>
218std::enable_if_t<!std::is_integral<typename Derived::Scalar>::value, std::unique_ptr<ArrayBase>>
219ArrayBase::row_slice_impl(
220 const Eigen::MatrixBase<Derived>& matrix,
221 Index num_rows,
222 const WeightedIndexFunction& mapping_fn)
223{
224 using Scalar = typename Derived::Scalar;
225 using OutEigenType = Eigen::Matrix<
226 Scalar,
227 Eigen::Dynamic,
228 Derived::ColsAtCompileTime,
229 Derived::IsRowMajor ? Eigen::RowMajor : Eigen::ColMajor>;
230 const auto num_cols = matrix.cols();
231 tbb::enumerable_thread_specific<std::vector<std::pair<Index, double>>> weights;
232
233 OutEigenType out_matrix(num_rows, num_cols);
234 out_matrix.setZero();
235
236 tbb::parallel_for(
237 tbb::blocked_range<Index>(0, num_rows),
238 [&](const tbb::blocked_range<Index>& r) {
239 for (auto i = r.begin(); i != r.end(); i++) {
240 auto& entries = weights.local();
241 mapping_fn(i, entries);
242 for (const auto& entry : entries) {
243 out_matrix.row(i) += matrix.row(entry.first).template cast<Scalar>() *
244 safe_cast<Scalar>(entry.second);
245 }
246 }
247 });
248 return std::make_unique<EigenArray<OutEigenType>>(std::move(out_matrix));
249}
250
251} // namespace experimental
252} // namespace lagrange
LA_CORE_API spdlog::logger & logger()
Retrieves the current logger.
Definition: Logger.cpp:40
#define la_runtime_assert(...)
Runtime assertion check.
Definition: assert.h:169
std::string string_format(fmt::format_string< Args... > format, Args &&... args)
Format args according to the format string fmt, and return the result as a string.
Definition: strings.h:103
Main namespace for Lagrange.
Definition: AABBIGL.h:30