14#include <lagrange/utils/warnoff.h>
15#include <tbb/enumerable_thread_specific.h>
16#include <tbb/parallel_for.h>
17#include <lagrange/utils/warnon.h>
21namespace experimental {
23template <
typename Derived>
24void ArrayBase::set(
const Eigen::MatrixBase<Derived>& data)
26 using EvalType = std::decay_t<
decltype(data.eval())>;
27 if (
auto ptr = down_cast<EigenArray<EvalType>*>()) {
30 auto ptr2 = down_cast<RawArray<
31 typename EvalType::Scalar,
32 EvalType::RowsAtCompileTime,
33 EvalType::ColsAtCompileTime,
34 EvalType::Options>*>()) {
36 }
else if (
auto ptr3 = down_cast<EigenArrayRef<EvalType>*>()) {
38 }
else if (is_compatible<EvalType>(
true)) {
41 if (EvalType::IsRowMajor == is_row_major()) {
42 resize(data.rows(), data.cols());
43 view<EvalType>() = data;
46 using TransposedType = std::decay_t<
decltype(data.transpose().eval())>;
48 resize(data.rows(), data.cols());
49 view<TransposedType>() = data;
53 "Unsupported type passed to ArrayBase::set(). Expecting {}",
58template <
typename Derived>
59void ArrayBase::set(Eigen::MatrixBase<Derived>&& data)
61 using EvalType = std::decay_t<
decltype(data.eval())>;
62 if (
auto ptr = down_cast<EigenArray<EvalType>*>()) {
63 ptr->set(std::move(data.derived()));
65 auto ptr2 = down_cast<RawArray<
66 typename EvalType::Scalar,
67 EvalType::RowsAtCompileTime,
68 EvalType::ColsAtCompileTime,
69 EvalType::Options>*>()) {
73 }
else if (
auto ptr3 = down_cast<EigenArrayRef<EvalType>*>()) {
74 ptr3->set(std::move(data.derived()));
75 }
else if (is_compatible<EvalType>(
true)) {
78 if (EvalType::IsRowMajor == is_row_major()) {
79 resize(data.rows(), data.cols());
80 view<EvalType>() = data;
83 using TransposedType = std::decay_t<
decltype(data.transpose().eval())>;
85 resize(data.rows(), data.cols());
86 view<TransposedType>() = std::move(data);
90 "Unsupported type passed to ArrayBase::set(). Expecting {}",
95template <
typename _TargetType>
98 using TargetType = std::decay_t<_TargetType>;
99 if (
auto ptr = down_cast<EigenArray<TargetType>*>()) {
100 return ptr->get_ref();
101 }
else if (
auto ptr2 = down_cast<EigenArrayRef<TargetType>*>()) {
102 return ptr2->get_ref();
105 "Unsupported type passed to ArrayBase::get(). Expecting {}",
110template <
typename _TargetType>
111const auto& ArrayBase::get()
const
113 using TargetType = std::decay_t<_TargetType>;
114 if (
auto ptr = down_cast<
const EigenArray<TargetType>*>()) {
115 return ptr->get_ref();
116 }
else if (
auto ptr3 = down_cast<
const EigenArrayRef<TargetType>*>()) {
117 return ptr3->get_ref();
118 }
else if (
auto ptr4 = down_cast<
const EigenArrayRef<const TargetType>*>()) {
119 return ptr4->get_ref();
122 "Unsupported type passed to const ArrayBase::get(). Expecting {}",
127template <
typename TargetType>
128bool ArrayBase::is_compatible(
bool ignore_storage_order)
const
130 using Scalar =
typename TargetType::Scalar;
131 constexpr int RowsAtCompileTime = TargetType::RowsAtCompileTime;
132 constexpr int ColsAtCompileTime = TargetType::ColsAtCompileTime;
133 constexpr int row_major = TargetType::IsRowMajor;
135 if (RowsAtCompileTime > 0 && this->rows() != RowsAtCompileTime)
return false;
136 if (ColsAtCompileTime > 0 && this->cols() != ColsAtCompileTime)
return false;
137 if (ScalarToEnum_v<Scalar> != m_scalar_type)
return false;
138 if (!ignore_storage_order && this->rows() != 1 && this->cols() != 1 &&
139 row_major != is_row_major()) {
141 "Target storage order ({}) does not match array storage order ({}).",
142 row_major == 1 ?
"RowMajor" :
"ColMajor",
143 is_row_major() == 1 ?
"RowMajor" :
"ColMajor");
151template <
typename Derived>
152std::unique_ptr<ArrayBase> ArrayBase::row_slice_impl(
153 const Eigen::MatrixBase<Derived>& matrix,
155 const IndexFunction& mapping_fn)
157 using Scalar =
typename Derived::Scalar;
158 using OutEigenType = Eigen::Matrix<
161 Derived::ColsAtCompileTime,
162 Derived::IsRowMajor ? Eigen::RowMajor : Eigen::ColMajor>;
164 const auto num_cols = matrix.cols();
165 OutEigenType out_matrix(num_rows, num_cols);
167 tbb::blocked_range<size_t>(0, num_rows),
168 [&](
const tbb::blocked_range<size_t>& r) {
169 for (
auto i = r.begin(); i != r.end(); i++) {
170 out_matrix.row(i) = matrix.row(mapping_fn(i));
174 return std::make_unique<EigenArray<OutEigenType>>(std::move(out_matrix));
177template <
typename Derived>
178std::enable_if_t<std::is_integral<typename Derived::Scalar>::value, std::unique_ptr<ArrayBase>>
179ArrayBase::row_slice_impl(
180 const Eigen::MatrixBase<Derived>& matrix,
182 const WeightedIndexFunction& mapping_fn)
184 using Scalar =
typename Derived::Scalar;
185 using OutEigenType = Eigen::Matrix<
188 Derived::ColsAtCompileTime,
189 Derived::IsRowMajor ? Eigen::RowMajor : Eigen::ColMajor>;
190 const auto num_cols = matrix.cols();
191 tbb::enumerable_thread_specific<std::vector<std::pair<Index, double>>> weights;
193 using DoubleEigenType = Eigen::Matrix<
196 Derived::ColsAtCompileTime,
197 Derived::IsRowMajor ? Eigen::RowMajor : Eigen::ColMajor>;
198 DoubleEigenType out_matrix(num_rows, num_cols);
199 out_matrix.setZero();
202 tbb::blocked_range<Index>(0, num_rows),
203 [&](
const tbb::blocked_range<Index>& r) {
204 for (
auto i = r.begin(); i != r.end(); i++) {
205 auto& entries = weights.local();
206 mapping_fn(i, entries);
207 for (const auto& entry : entries) {
209 matrix.row(entry.first).template cast<double>() * entry.second;
213 return std::make_unique<EigenArray<OutEigenType>>(
214 std::move(out_matrix.array().round().template cast<Scalar>().matrix().eval()));
217template <
typename Derived>
218std::enable_if_t<!std::is_integral<typename Derived::Scalar>::value, std::unique_ptr<ArrayBase>>
219ArrayBase::row_slice_impl(
220 const Eigen::MatrixBase<Derived>& matrix,
222 const WeightedIndexFunction& mapping_fn)
224 using Scalar =
typename Derived::Scalar;
225 using OutEigenType = Eigen::Matrix<
228 Derived::ColsAtCompileTime,
229 Derived::IsRowMajor ? Eigen::RowMajor : Eigen::ColMajor>;
230 const auto num_cols = matrix.cols();
231 tbb::enumerable_thread_specific<std::vector<std::pair<Index, double>>> weights;
233 OutEigenType out_matrix(num_rows, num_cols);
234 out_matrix.setZero();
237 tbb::blocked_range<Index>(0, num_rows),
238 [&](
const tbb::blocked_range<Index>& r) {
239 for (
auto i = r.begin(); i != r.end(); i++) {
240 auto& entries = weights.local();
241 mapping_fn(i, entries);
242 for (const auto& entry : entries) {
243 out_matrix.row(i) += matrix.row(entry.first).template cast<Scalar>() *
244 safe_cast<Scalar>(entry.second);
248 return std::make_unique<EigenArray<OutEigenType>>(std::move(out_matrix));
LA_CORE_API spdlog::logger & logger()
Retrieves the current logger.
Definition: Logger.cpp:40
#define la_runtime_assert(...)
Runtime assertion check.
Definition: assert.h:169
std::string string_format(fmt::format_string< Args... > format, Args &&... args)
Format args according to the format string fmt, and return the result as a string.
Definition: strings.h:103
Main namespace for Lagrange.
Definition: AABBIGL.h:30