Lagrange
Array.h
1/*
2 * Copyright 2020 Adobe. All rights reserved.
3 * This file is licensed to you under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License. You may obtain a copy
5 * of the License at http://www.apache.org/licenses/LICENSE-2.0
6 *
7 * Unless required by applicable law or agreed to in writing, software distributed under
8 * the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
9 * OF ANY KIND, either express or implied. See the License for the specific language
10 * governing permissions and limitations under the License.
11 */
12#pragma once
13
14#include <lagrange/Logger.h>
15#include <lagrange/common.h>
16#include <lagrange/experimental/Scalar.h>
17#include <lagrange/utils/assert.h>
18#include <lagrange/utils/build.h>
19#include <lagrange/utils/range.h>
20#include <lagrange/utils/strings.h>
21
22#include <algorithm>
23#include <any>
24#include <exception>
25#include <iostream>
#include <new>
26
27namespace lagrange {
28namespace experimental {
29
30template <typename Data>
32{
33};
34
36{
37public:
38 using Index = Eigen::Index; // Default to std::ptrdiff_t
39
40public:
41 ArrayBase(ScalarEnum type)
42 : m_scalar_type(type)
43 {}
44
45 virtual ~ArrayBase() = default;
46
47public:
/// Mutable Eigen::Map of type `TargetType` over this array's buffer, sized rows() x cols().
///
/// The pointer comes from data<TargetType::Scalar>(), which runtime-asserts that the
/// scalar type matches. The arguments visible below pass is_compatible<TargetType>()
/// and an error message to a check.
/// NOTE(review): the opening of that check (original source line 51, presumably
/// `la_runtime_assert(`) is missing from this listing. Confirm against the real header.
48 template <typename TargetType>
49 Eigen::Map<TargetType> view()
50 {
52 is_compatible<TargetType>(),
53 "Target view type is not compatible with the data.");
54 return Eigen::Map<TargetType>(data<typename TargetType::Scalar>(), rows(), cols());
55 }
56
/// Read-only Eigen::Map of type `const TargetType` over this array's buffer,
/// sized rows() x cols().
///
/// The pointer comes from the const data<TargetType::Scalar>(), which runtime-asserts
/// that the scalar type matches. The arguments visible below pass
/// is_compatible<TargetType>() and an error message to a check.
/// NOTE(review): the opening of that check (original source line 60, presumably
/// `la_runtime_assert(`) is missing from this listing. Confirm against the real header.
57 template <typename TargetType>
58 Eigen::Map<const TargetType> view() const
59 {
61 is_compatible<TargetType>(),
62 "Target view type is not compatible with the data.");
63 return Eigen::Map<const TargetType>(data<typename TargetType::Scalar>(), rows(), cols());
64 }
65
66 virtual std::any get_type_info() const = 0;
67
68 template <typename Derived>
69 bool is_base_of() const
70 {
71 std::any info = get_type_info();
72 return std::any_cast<ArrayTypeInfo<Derived>>(&info) != nullptr;
73 }
74
75 template <typename DerivedPtr>
76 auto down_cast() -> std::add_pointer_t<std::decay_t<std::remove_pointer_t<DerivedPtr>>>
77 {
78 using Derived = std::decay_t<std::remove_pointer_t<DerivedPtr>>;
79#if LAGRANGE_ENABLE_RTTI
80 return dynamic_cast<Derived*>(this);
81#else
82 if (is_base_of<Derived>()) {
83 return static_cast<Derived*>(this);
84 }
85 return nullptr;
86#endif
87 }
88
89 template <typename DerivedPtr>
90 auto down_cast() const
91 -> std::add_pointer_t<std::add_const_t<std::decay_t<std::remove_pointer_t<DerivedPtr>>>>
92 {
93 using Derived = std::decay_t<std::remove_pointer_t<DerivedPtr>>;
94#if LAGRANGE_ENABLE_RTTI
95 return dynamic_cast<const Derived*>(this);
96#else
97 if (is_base_of<Derived>()) {
98 return static_cast<const Derived*>(this);
99 }
100 return nullptr;
101#endif
102 }
103
104 template <typename Derived>
105 void set(const Eigen::MatrixBase<Derived>& data);
106
107 template <typename Derived>
108 void set(Eigen::MatrixBase<Derived>&& data);
109
110 ScalarEnum get_scalar_type() const { return m_scalar_type; }
111
112 template <typename TargetType>
113 const auto& get() const;
114
115 template <typename TargetType>
116 auto& get();
117
118 template <typename Scalar>
119 Scalar* data()
120 {
121 la_runtime_assert(ScalarToEnum_v<Scalar> == m_scalar_type);
122 return static_cast<Scalar*>(data());
123 }
124
125 template <typename Scalar>
126 const Scalar* data() const
127 {
128 la_runtime_assert(ScalarToEnum_v<Scalar> == m_scalar_type);
129 return static_cast<const Scalar*>(data());
130 }
131
132 virtual Index rows() const = 0;
133 virtual Index cols() const = 0;
134 virtual bool is_row_major() const = 0;
135 virtual void* data() = 0;
136 virtual const void* data() const = 0;
137 virtual void resize(Index, Index) = 0;
138 virtual std::unique_ptr<ArrayBase> clone() const = 0;
139
140 using IndexFunction = std::function<Index(Index)>;
141 using WeightedIndexFunction =
142 std::function<void(Index, std::vector<std::pair<Index, double>>&)>;
143
144 template <typename T>
145 std::unique_ptr<ArrayBase> row_slice(const std::vector<T>& row_indices) const
146 {
147 return row_slice(safe_cast<Index>(row_indices.size()), [&](Index i) {
148 return safe_cast<Index>(row_indices[i]);
149 });
150 }
151
158 virtual std::unique_ptr<ArrayBase> row_slice(Index num_rows, const IndexFunction& mapping_fn)
159 const = 0;
160
167 virtual std::unique_ptr<ArrayBase> row_slice(
168 Index num_rows,
169 const WeightedIndexFunction& mapping_fn) const = 0;
170
171 virtual std::string type_name() const = 0;
172
173protected:
/// Checks whether `TargetType` can be used to view this array's data (used by view()).
/// Defined out of line. `ignore_storage_order` presumably skips the row/column-major
/// comparison. Confirm in the definition.
174 template <typename TargetType>
175 bool is_compatible(bool ignore_storage_order = false) const;
176
/// Shared implementations behind the concrete row_slice() overrides. Defined out of line.
/// NOTE(review): these declarations must stay token-for-token identical to their
/// out-of-line definitions (e.g. `is_integral<...>::value` must not become `is_integral_v`),
/// otherwise the definitions no longer match.
/// Plain index mapping: one source row per output row.
177 template <typename Derived>
178 static std::unique_ptr<ArrayBase> row_slice_impl(
179 const Eigen::MatrixBase<Derived>& matrix,
180 Index num_rows,
181 const IndexFunction& mapping);
182
/// Weighted mapping, selected when Derived::Scalar is an integral type.
/// (Presumably handled separately because weighted sums of integers need conversion. Confirm.)
183 template <typename Derived>
184 static std::
185 enable_if_t<std::is_integral<typename Derived::Scalar>::value, std::unique_ptr<ArrayBase>>
186 row_slice_impl(
187 const Eigen::MatrixBase<Derived>& matrix,
188 Index num_rows,
189 const WeightedIndexFunction& mapping);
190
/// Weighted mapping, selected when Derived::Scalar is not an integral type.
191 template <typename Derived>
192 static std::
193 enable_if_t<!std::is_integral<typename Derived::Scalar>::value, std::unique_ptr<ArrayBase>>
194 row_slice_impl(
195 const Eigen::MatrixBase<Derived>& matrix,
196 Index num_rows,
197 const WeightedIndexFunction& mapping);
198
199protected:
/// Runtime scalar-type tag set by the constructor. It is const, so it never changes.
200 const ScalarEnum m_scalar_type;
201};
202
207template <typename _EigenType>
208class EigenArray : public ArrayBase
209{
210public:
211 EIGEN_MAKE_ALIGNED_OPERATOR_NEW
212 using EigenType = _EigenType;
214 using Index = ArrayBase::Index;
215 using ArrayBase::WeightedIndexFunction;
216 static_assert(
217 std::is_base_of<typename Eigen::EigenBase<EigenType>, EigenType>::value,
218 "Template parameter `_EigenType` is not an Eigen type!");
219 static_assert(!std::is_const<EigenType>::value, "EigenType should not be const type.");
220 static_assert(!std::is_reference<EigenType>::value, "EigenType should not be reference type.");
221
222public:
223 EigenArray()
224 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
225 {}
226
227 template <typename T>
228 explicit EigenArray(const Eigen::MatrixBase<T>& data)
229 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
230 , m_data(data)
231 {}
232
233 template <typename T>
234 explicit EigenArray(Eigen::MatrixBase<T>&& data)
235 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
236 , m_data(std::move(data.derived()))
237 {}
238
239 std::any get_type_info() const override { return std::make_any<ArrayTypeInfo<Self>>(); }
240
241public:
242 EigenType& get_ref() { return m_data; }
243 const EigenType& get_ref() const { return m_data; }
244
245 template <typename Derived>
246 void set(Derived&& data)
247 {
248 m_data = std::forward<Derived>(data);
249 }
250
251 Index rows() const override { return m_data.rows(); }
252 Index cols() const override { return m_data.cols(); }
253 bool is_row_major() const override { return EigenType::IsRowMajor != 0; }
254 void* data() override { return static_cast<void*>(m_data.derived().data()); }
255 const void* data() const override { return static_cast<const void*>(m_data.derived().data()); }
256 void resize(Index r, Index c) override { m_data.resize(r, c); }
257
258 std::unique_ptr<ArrayBase> clone() const override
259 {
260 return std::make_unique<EigenArray<EigenType>>(m_data);
261 }
262
263 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const IndexFunction& mapping_fn)
264 const override
265 {
266 return row_slice_impl(m_data, num_rows, mapping_fn);
267 }
268 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const WeightedIndexFunction& mapping_fn)
269 const override
270 {
271 return row_slice_impl(m_data, num_rows, mapping_fn);
272 }
273
274 std::string type_name() const override
275 {
276 std::string scalar_name(ScalarToEnum<typename EigenType::Scalar>::name);
277 return string_format(
278 "EigenArray<Eigen::Matrix<{}, {}, {}, {}>>",
279 scalar_name,
280 static_cast<int>(EigenType::RowsAtCompileTime),
281 static_cast<int>(EigenType::ColsAtCompileTime),
282 static_cast<int>(EigenType::Options));
283 }
284
285private:
286 EigenType m_data;
287};
288
289
294template <typename _EigenType, bool IsConst = std::is_const<_EigenType>::value>
296
297template <typename _EigenType>
298class EigenArrayRef<_EigenType, false> : public ArrayBase
299{
300public:
301 using EigenType = _EigenType;
303 using DecayedEigenType = std::decay_t<_EigenType>;
304 using Index = ArrayBase::Index;
305 using ArrayBase::WeightedIndexFunction;
306 static_assert(
307 std::is_base_of<typename Eigen::EigenBase<DecayedEigenType>, DecayedEigenType>::value,
308 "Template parameter `_EigenType` is not an Eigen type!");
309 static_assert(!std::is_const<EigenType>::value, "EigenType should not be const type.");
310 static_assert(!std::is_reference<EigenType>::value, "EigenType should not be reference type.");
311
312public:
313 template <typename T>
314 explicit EigenArrayRef(Eigen::MatrixBase<T>& data)
315 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
316 , m_data(data.derived())
317 {}
318
319 std::any get_type_info() const override { return std::make_any<ArrayTypeInfo<Self>>(); }
320
321public:
322 EigenType& get_ref() { return m_data; }
323 const EigenType& get_ref() const { return m_data; }
324
325 template <typename Derived>
326 void set(Derived&& data)
327 {
328 m_data = std::forward<Derived>(data);
329 }
330
331 Index rows() const override { return m_data.rows(); }
332 Index cols() const override { return m_data.cols(); }
333 bool is_row_major() const override { return EigenType::IsRowMajor != 0; }
334 void* data() override { return static_cast<void*>(m_data.derived().data()); }
335 const void* data() const override { return static_cast<const void*>(m_data.derived().data()); }
336 void resize(Index r, Index c) override { m_data.resize(r, c); }
337
338 std::unique_ptr<ArrayBase> clone() const override
339 {
340 return std::make_unique<EigenArray<DecayedEigenType>>(m_data);
341 }
342
343 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const IndexFunction& mapping_fn)
344 const override
345 {
346 return row_slice_impl(m_data, num_rows, mapping_fn);
347 }
348 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const WeightedIndexFunction& mapping_fn)
349 const override
350 {
351 return row_slice_impl(m_data, num_rows, mapping_fn);
352 }
353
354 std::string type_name() const override
355 {
356 std::string scalar_name(ScalarToEnum<typename EigenType::Scalar>::name);
357 return string_format(
358 "EigenArrayRef<Eigen::Matrix<{}, {}, {}, {}>>",
359 scalar_name,
360 static_cast<int>(EigenType::RowsAtCompileTime),
361 static_cast<int>(EigenType::ColsAtCompileTime),
362 static_cast<int>(EigenType::Options));
363 }
364
365private:
366 EigenType& m_data;
367};
368
369template <typename _EigenType>
370class EigenArrayRef<_EigenType, true> : public ArrayBase
371{
372public:
373 using EigenType = _EigenType;
375 using DecayedEigenType = std::decay_t<_EigenType>;
376 using Index = ArrayBase::Index;
377 using ArrayBase::WeightedIndexFunction;
378 static_assert(
379 std::is_base_of<typename Eigen::EigenBase<DecayedEigenType>, DecayedEigenType>::value,
380 "Template parameter `_EigenType` is not an Eigen type!");
381 static_assert(std::is_const<EigenType>::value, "EigenType must const type.");
382 static_assert(!std::is_reference<EigenType>::value, "EigenType should not be reference type.");
383
384public:
385 template <typename T>
386 explicit EigenArrayRef(const Eigen::MatrixBase<T>& data)
387 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
388 , m_data(data.derived())
389 {}
390
391 std::any get_type_info() const override { return std::make_any<ArrayTypeInfo<Self>>(); }
392
393public:
394 const EigenType& get_ref() const { return m_data; }
395
396 Index rows() const override { return m_data.rows(); }
397 Index cols() const override { return m_data.cols(); }
398 bool is_row_major() const override { return EigenType::IsRowMajor != 0; }
399 const void* data() const override { return static_cast<const void*>(m_data.derived().data()); }
400 void* data() override { throw std::runtime_error("This method is not supported"); }
401
402 void resize(Index r, Index c) override
403 {
404 if (r != m_data.rows() || c != m_data.cols()) {
405 throw std::runtime_error("Resizing const EigenArrayRef is not allowed.");
406 }
407 }
408
409 std::unique_ptr<ArrayBase> clone() const override
410 {
411 return std::make_unique<EigenArray<DecayedEigenType>>(m_data);
412 }
413
414 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const IndexFunction& mapping_fn)
415 const override
416 {
417 return row_slice_impl(m_data, num_rows, mapping_fn);
418 }
419 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const WeightedIndexFunction& mapping_fn)
420 const override
421 {
422 return row_slice_impl(m_data, num_rows, mapping_fn);
423 }
424
425 std::string type_name() const override
426 {
427 std::string scalar_name(ScalarToEnum<typename EigenType::Scalar>::name);
428 return string_format(
429 "EigenArrayRef<const Eigen::Matrix<{}, {}, {}, {}>>",
430 scalar_name,
431 static_cast<int>(EigenType::RowsAtCompileTime),
432 static_cast<int>(EigenType::ColsAtCompileTime),
433 static_cast<int>(EigenType::Options));
434 }
435
436private:
437 EigenType& m_data;
438};
439
440
445template <
446 typename _Scalar,
447 int _Rows = Eigen::Dynamic,
448 int _Cols = Eigen::Dynamic,
449 int _Options = Eigen::RowMajor,
450 bool IsConst = std::is_const<_Scalar>::value>
452
453template <typename _Scalar, int _Rows, int _Cols, int _Options>
454class RawArray<_Scalar, _Rows, _Cols, _Options, false> : public ArrayBase
455{
456public:
458 using Scalar = _Scalar;
459 using Index = ArrayBase::Index;
460 using EigenType = Eigen::Matrix<Scalar, _Rows, _Cols, _Options>;
461 using EigenMap = Eigen::Map<EigenType>;
462 using ConstEigenMap = Eigen::Map<const EigenType>;
463 using ArrayBase::WeightedIndexFunction;
464
465public:
466 RawArray(Scalar* data, Index rows, Index cols)
467 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
468 , m_data(data, rows, cols)
469 {}
470
471 std::any get_type_info() const override { return std::make_any<ArrayTypeInfo<Self>>(); }
472
473public:
474 EigenMap& get_ref() { return m_data; }
475 ConstEigenMap get_ref() const { return m_data; }
476
477 template <typename Derived>
478 void set(const Eigen::MatrixBase<Derived>& data)
479 {
480 m_data = data;
481 }
482
483 void set(Scalar* data, Index rows, Index cols) { m_data = EigenMap(data, rows, cols); }
484
485 Index rows() const override { return m_data.rows(); }
486 Index cols() const override { return m_data.cols(); }
487 bool is_row_major() const override { return EigenType::IsRowMajor != 0; }
488 void* data() override { return static_cast<void*>(m_data.data()); }
489 const void* data() const override { return static_cast<const void*>(m_data.data()); }
490 void resize(Index r, Index c) override
491 {
492 if (r != m_data.rows() || c != m_data.cols()) {
493 throw std::runtime_error("Resizing RawArray is not allowed.");
494 }
495 }
496
497 std::unique_ptr<ArrayBase> clone() const override
498 {
499 return std::make_unique<EigenArray<EigenType>>(m_data);
500 }
501
502 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const IndexFunction& mapping_fn)
503 const override
504 {
505 return row_slice_impl(m_data, num_rows, mapping_fn);
506 }
507 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const WeightedIndexFunction& mapping_fn)
508 const override
509 {
510 return row_slice_impl(m_data, num_rows, mapping_fn);
511 }
512
513 std::string type_name() const override
514 {
515 std::string scalar_name(ScalarToEnum<Scalar>::name);
516 return string_format("RawArray<{}, {}, {}, {}>", scalar_name, _Rows, _Cols, _Options);
517 }
518
519private:
520 EigenMap m_data;
521};
522
523template <typename _Scalar, int _Rows, int _Cols, int _Options>
524class RawArray<_Scalar, _Rows, _Cols, _Options, true> : public ArrayBase
525{
526public:
528 using Scalar = std::decay_t<_Scalar>;
529 using Index = ArrayBase::Index;
530 using EigenType = Eigen::Matrix<Scalar, _Rows, _Cols, _Options>;
531 using EigenMap = Eigen::Map<const EigenType>;
532 using ArrayBase::WeightedIndexFunction;
533
534public:
535 RawArray(const Scalar* data, Index rows, Index cols)
536 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
537 , m_data(data, rows, cols)
538 {}
539
540 std::any get_type_info() const override { return std::make_any<ArrayTypeInfo<Self>>(); }
541
542public:
543 const EigenMap& get_ref() const { return m_data; }
544
545 Index rows() const override { return m_data.rows(); }
546 Index cols() const override { return m_data.cols(); }
547 bool is_row_major() const override { return EigenType::IsRowMajor != 0; }
548 const void* data() const override { return static_cast<const void*>(m_data.data()); }
549 void* data() override { throw std::runtime_error("This method is not supported"); }
550
551 void resize(Index r, Index c) override
552 {
553 if (r != m_data.rows() || c != m_data.cols()) {
554 throw std::runtime_error("Resizing RawArray is not allowed.");
555 }
556 }
557
558 std::unique_ptr<ArrayBase> clone() const override
559 {
560 return std::make_unique<EigenArray<EigenType>>(m_data);
561 }
562
563 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const IndexFunction& mapping_fn)
564 const override
565 {
566 return row_slice_impl(m_data, num_rows, mapping_fn);
567 }
568 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const WeightedIndexFunction& mapping_fn)
569 const override
570 {
571 return row_slice_impl(m_data, num_rows, mapping_fn);
572 }
573
574 std::string type_name() const override
575 {
576 std::string scalar_name(ScalarToEnum<Scalar>::name);
577 return string_format("RawArray<const {}, {}, {}, {}>", scalar_name, _Rows, _Cols, _Options);
578 }
579
580private:
581 const EigenMap m_data;
582};
583
584} // namespace experimental
585} // namespace lagrange
586
587#include "Array.impl.h"
588#include "Array.serialization.h"
Definition: Array.h:36
virtual std::unique_ptr< ArrayBase > row_slice(Index num_rows, const IndexFunction &mapping_fn) const =0
Using index function for row mapping.
virtual std::unique_ptr< ArrayBase > row_slice(Index num_rows, const WeightedIndexFunction &mapping_fn) const =0
This is the most generic version of row_slice method.
This class is a thin wrapper around an Eigen matrix.
Definition: Array.h:209
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const IndexFunction &mapping_fn) const override
Using index function for row mapping.
Definition: Array.h:263
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const WeightedIndexFunction &mapping_fn) const override
This is the most generic version of row_slice method.
Definition: Array.h:268
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const IndexFunction &mapping_fn) const override
Using index function for row mapping.
Definition: Array.h:343
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const WeightedIndexFunction &mapping_fn) const override
This is the most generic version of row_slice method.
Definition: Array.h:348
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const IndexFunction &mapping_fn) const override
Using index function for row mapping.
Definition: Array.h:414
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const WeightedIndexFunction &mapping_fn) const override
This is the most generic version of row_slice method.
Definition: Array.h:419
This class is a thin wrapper around an Eigen matrix.
Definition: Array.h:295
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const IndexFunction &mapping_fn) const override
Using index function for row mapping.
Definition: Array.h:502
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const WeightedIndexFunction &mapping_fn) const override
This is the most generic version of row_slice method.
Definition: Array.h:507
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const IndexFunction &mapping_fn) const override
Using index function for row mapping.
Definition: Array.h:563
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const WeightedIndexFunction &mapping_fn) const override
This is the most generic version of row_slice method.
Definition: Array.h:568
This class provides a thin wrapper around a raw array.
Definition: Array.h:451
#define la_runtime_assert(...)
Runtime assertion check.
Definition: assert.h:169
std::string string_format(fmt::format_string< Args... > format, Args &&... args)
Format args according to the format string fmt, and return the result as a string.
Definition: strings.h:103
Main namespace for Lagrange.
Definition: AABBIGL.h:30