Lagrange
Array.h
1/*
2 * Copyright 2020 Adobe. All rights reserved.
3 * This file is licensed to you under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License. You may obtain a copy
5 * of the License at http://www.apache.org/licenses/LICENSE-2.0
6 *
7 * Unless required by applicable law or agreed to in writing, software distributed under
8 * the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
9 * OF ANY KIND, either express or implied. See the License for the specific language
10 * governing permissions and limitations under the License.
11 */
12#pragma once
13
14#include <lagrange/Logger.h>
15#include <lagrange/common.h>
16#include <lagrange/experimental/Scalar.h>
17#include <lagrange/utils/assert.h>
18#include <lagrange/utils/range.h>
19#include <lagrange/utils/strings.h>
20
#include <algorithm>
#include <any>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <new>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
25
26namespace lagrange {
27namespace experimental {
28
/**
 * Empty tag type used to identify the dynamic type of an array.
 *
 * Each concrete array class returns `ArrayTypeInfo<Self>` wrapped in a
 * std::any from get_type_info(); ArrayBase::is_base_of<Derived>() then tests
 * for that exact tag with std::any_cast.
 */
template <typename Data>
struct ArrayTypeInfo
{
};
33
35{
36public:
37 using Index = Eigen::Index; // Default to std::ptrdiff_t
38
39public:
40 ArrayBase(ScalarEnum type)
41 : m_scalar_type(type)
42 {}
43
44 virtual ~ArrayBase() = default;
45
46public:
47 template <typename TargetType>
48 Eigen::Map<TargetType> view()
49 {
51 is_compatible<TargetType>(),
52 "Target view type is not compatible with the data.");
53 return Eigen::Map<TargetType>(data<typename TargetType::Scalar>(), rows(), cols());
54 }
55
56 template <typename TargetType>
57 Eigen::Map<const TargetType> view() const
58 {
60 is_compatible<TargetType>(),
61 "Target view type is not compatible with the data.");
62 return Eigen::Map<const TargetType>(data<typename TargetType::Scalar>(), rows(), cols());
63 }
64
65 virtual std::any get_type_info() const = 0;
66
67 template <typename Derived>
68 bool is_base_of() const
69 {
70 std::any info = get_type_info();
71 return std::any_cast<ArrayTypeInfo<Derived>>(&info) != nullptr;
72 }
73
74 template <typename DerivedPtr>
75 auto down_cast() -> std::add_pointer_t<std::decay_t<std::remove_pointer_t<DerivedPtr>>>
76 {
77 using Derived = std::decay_t<std::remove_pointer_t<DerivedPtr>>;
78 if (is_base_of<Derived>()) {
79 return static_cast<Derived*>(this);
80 }
81 return nullptr;
82 }
83
84 template <typename DerivedPtr>
85 auto down_cast() const
86 -> std::add_pointer_t<std::add_const_t<std::decay_t<std::remove_pointer_t<DerivedPtr>>>>
87 {
88 using Derived = std::decay_t<std::remove_pointer_t<DerivedPtr>>;
89 if (is_base_of<Derived>()) {
90 return static_cast<const Derived*>(this);
91 }
92 return nullptr;
93 }
94
95 template <typename Derived>
96 void set(const Eigen::MatrixBase<Derived>& data);
97
98 template <typename Derived>
99 void set(Eigen::MatrixBase<Derived>&& data);
100
101 ScalarEnum get_scalar_type() const { return m_scalar_type; }
102
103 template <typename TargetType>
104 const auto& get() const;
105
106 template <typename TargetType>
107 auto& get();
108
109 template <typename Scalar>
110 Scalar* data()
111 {
112 la_runtime_assert(ScalarToEnum_v<Scalar> == m_scalar_type);
113 return static_cast<Scalar*>(data());
114 }
115
116 template <typename Scalar>
117 const Scalar* data() const
118 {
119 la_runtime_assert(ScalarToEnum_v<Scalar> == m_scalar_type);
120 return static_cast<const Scalar*>(data());
121 }
122
123 virtual Index rows() const = 0;
124 virtual Index cols() const = 0;
125 virtual bool is_row_major() const = 0;
126 virtual void* data() = 0;
127 virtual const void* data() const = 0;
128 virtual void resize(Index, Index) = 0;
129 virtual std::unique_ptr<ArrayBase> clone() const = 0;
130
131 using IndexFunction = std::function<Index(Index)>;
132 using WeightedIndexFunction =
133 std::function<void(Index, std::vector<std::pair<Index, double>>&)>;
134
135 template <typename T>
136 std::unique_ptr<ArrayBase> row_slice(const std::vector<T>& row_indices) const
137 {
138 return row_slice(safe_cast<Index>(row_indices.size()), [&](Index i) {
139 return safe_cast<Index>(row_indices[i]);
140 });
141 }
142
149 virtual std::unique_ptr<ArrayBase> row_slice(Index num_rows, const IndexFunction& mapping_fn)
150 const = 0;
151
158 virtual std::unique_ptr<ArrayBase> row_slice(
159 Index num_rows,
160 const WeightedIndexFunction& mapping_fn) const = 0;
161
162 virtual std::string type_name() const = 0;
163
164protected:
165 template <typename TargetType>
166 bool is_compatible(bool ignore_storage_order = false) const;
167
168 template <typename Derived>
169 static std::unique_ptr<ArrayBase> row_slice_impl(
170 const Eigen::MatrixBase<Derived>& matrix,
171 Index num_rows,
172 const IndexFunction& mapping);
173
174 template <typename Derived>
175 static std::
176 enable_if_t<std::is_integral<typename Derived::Scalar>::value, std::unique_ptr<ArrayBase>>
177 row_slice_impl(
178 const Eigen::MatrixBase<Derived>& matrix,
179 Index num_rows,
180 const WeightedIndexFunction& mapping);
181
182 template <typename Derived>
183 static std::
184 enable_if_t<!std::is_integral<typename Derived::Scalar>::value, std::unique_ptr<ArrayBase>>
185 row_slice_impl(
186 const Eigen::MatrixBase<Derived>& matrix,
187 Index num_rows,
188 const WeightedIndexFunction& mapping);
189
190protected:
191 const ScalarEnum m_scalar_type;
192};
193
198template <typename _EigenType>
199class EigenArray : public ArrayBase
200{
201public:
202 EIGEN_MAKE_ALIGNED_OPERATOR_NEW
203 using EigenType = _EigenType;
205 using Index = ArrayBase::Index;
206 using ArrayBase::WeightedIndexFunction;
207 static_assert(
208 std::is_base_of<typename Eigen::EigenBase<EigenType>, EigenType>::value,
209 "Template parameter `_EigenType` is not an Eigen type!");
210 static_assert(!std::is_const<EigenType>::value, "EigenType should not be const type.");
211 static_assert(!std::is_reference<EigenType>::value, "EigenType should not be reference type.");
212
213public:
214 EigenArray()
215 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
216 {}
217
218 template <typename T>
219 explicit EigenArray(const Eigen::MatrixBase<T>& data)
220 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
221 , m_data(data)
222 {}
223
224 template <typename T>
225 explicit EigenArray(Eigen::MatrixBase<T>&& data)
226 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
227 , m_data(std::move(data.derived()))
228 {}
229
230 std::any get_type_info() const override { return std::make_any<ArrayTypeInfo<Self>>(); }
231
232public:
233 EigenType& get_ref() { return m_data; }
234 const EigenType& get_ref() const { return m_data; }
235
236 template <typename Derived>
237 void set(Derived&& data)
238 {
239 m_data = std::forward<Derived>(data);
240 }
241
242 Index rows() const override { return m_data.rows(); }
243 Index cols() const override { return m_data.cols(); }
244 bool is_row_major() const override { return EigenType::IsRowMajor != 0; }
245 void* data() override { return static_cast<void*>(m_data.derived().data()); }
246 const void* data() const override { return static_cast<const void*>(m_data.derived().data()); }
247 void resize(Index r, Index c) override { m_data.resize(r, c); }
248
249 std::unique_ptr<ArrayBase> clone() const override
250 {
251 return std::make_unique<EigenArray<EigenType>>(m_data);
252 }
253
254 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const IndexFunction& mapping_fn)
255 const override
256 {
257 return row_slice_impl(m_data, num_rows, mapping_fn);
258 }
259 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const WeightedIndexFunction& mapping_fn)
260 const override
261 {
262 return row_slice_impl(m_data, num_rows, mapping_fn);
263 }
264
265 std::string type_name() const override
266 {
267 std::string scalar_name(ScalarToEnum<typename EigenType::Scalar>::name);
268 return string_format(
269 "EigenArray<Eigen::Matrix<{}, {}, {}, {}>>",
270 scalar_name,
271 static_cast<int>(EigenType::RowsAtCompileTime),
272 static_cast<int>(EigenType::ColsAtCompileTime),
273 static_cast<int>(EigenType::Options));
274 }
275
276private:
277 EigenType m_data;
278};
279
280
/**
 * This class is a thin wrapper around an Eigen matrix.
 *
 * Unlike EigenArray, EigenArrayRef does not own the matrix: it stores a
 * reference to a matrix owned elsewhere, which must therefore outlive the
 * wrapper. `IsConst` selects the specialization: `EigenArrayRef<T, false>` is
 * mutable and `EigenArrayRef<const T, true>` is read-only.
 */
template <typename _EigenType, bool IsConst = std::is_const<_EigenType>::value>
class EigenArrayRef;

288template <typename _EigenType>
289class EigenArrayRef<_EigenType, false> : public ArrayBase
290{
291public:
292 using EigenType = _EigenType;
294 using DecayedEigenType = std::decay_t<_EigenType>;
295 using Index = ArrayBase::Index;
296 using ArrayBase::WeightedIndexFunction;
297 static_assert(
298 std::is_base_of<typename Eigen::EigenBase<DecayedEigenType>, DecayedEigenType>::value,
299 "Template parameter `_EigenType` is not an Eigen type!");
300 static_assert(!std::is_const<EigenType>::value, "EigenType should not be const type.");
301 static_assert(!std::is_reference<EigenType>::value, "EigenType should not be reference type.");
302
303public:
304 template <typename T>
305 explicit EigenArrayRef(Eigen::MatrixBase<T>& data)
306 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
307 , m_data(data.derived())
308 {}
309
310 std::any get_type_info() const override { return std::make_any<ArrayTypeInfo<Self>>(); }
311
312public:
313 EigenType& get_ref() { return m_data; }
314 const EigenType& get_ref() const { return m_data; }
315
316 template <typename Derived>
317 void set(Derived&& data)
318 {
319 m_data = std::forward<Derived>(data);
320 }
321
322 Index rows() const override { return m_data.rows(); }
323 Index cols() const override { return m_data.cols(); }
324 bool is_row_major() const override { return EigenType::IsRowMajor != 0; }
325 void* data() override { return static_cast<void*>(m_data.derived().data()); }
326 const void* data() const override { return static_cast<const void*>(m_data.derived().data()); }
327 void resize(Index r, Index c) override { m_data.resize(r, c); }
328
329 std::unique_ptr<ArrayBase> clone() const override
330 {
331 return std::make_unique<EigenArray<DecayedEigenType>>(m_data);
332 }
333
334 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const IndexFunction& mapping_fn)
335 const override
336 {
337 return row_slice_impl(m_data, num_rows, mapping_fn);
338 }
339 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const WeightedIndexFunction& mapping_fn)
340 const override
341 {
342 return row_slice_impl(m_data, num_rows, mapping_fn);
343 }
344
345 std::string type_name() const override
346 {
347 std::string scalar_name(ScalarToEnum<typename EigenType::Scalar>::name);
348 return string_format(
349 "EigenArrayRef<Eigen::Matrix<{}, {}, {}, {}>>",
350 scalar_name,
351 static_cast<int>(EigenType::RowsAtCompileTime),
352 static_cast<int>(EigenType::ColsAtCompileTime),
353 static_cast<int>(EigenType::Options));
354 }
355
356private:
357 EigenType& m_data;
358};
359
360template <typename _EigenType>
361class EigenArrayRef<_EigenType, true> : public ArrayBase
362{
363public:
364 using EigenType = _EigenType;
366 using DecayedEigenType = std::decay_t<_EigenType>;
367 using Index = ArrayBase::Index;
368 using ArrayBase::WeightedIndexFunction;
369 static_assert(
370 std::is_base_of<typename Eigen::EigenBase<DecayedEigenType>, DecayedEigenType>::value,
371 "Template parameter `_EigenType` is not an Eigen type!");
372 static_assert(std::is_const<EigenType>::value, "EigenType must const type.");
373 static_assert(!std::is_reference<EigenType>::value, "EigenType should not be reference type.");
374
375public:
376 template <typename T>
377 explicit EigenArrayRef(const Eigen::MatrixBase<T>& data)
378 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
379 , m_data(data.derived())
380 {}
381
382 std::any get_type_info() const override { return std::make_any<ArrayTypeInfo<Self>>(); }
383
384public:
385 const EigenType& get_ref() const { return m_data; }
386
387 Index rows() const override { return m_data.rows(); }
388 Index cols() const override { return m_data.cols(); }
389 bool is_row_major() const override { return EigenType::IsRowMajor != 0; }
390 const void* data() const override { return static_cast<const void*>(m_data.derived().data()); }
391 void* data() override { throw std::runtime_error("This method is not supported"); }
392
393 void resize(Index r, Index c) override
394 {
395 if (r != m_data.rows() || c != m_data.cols()) {
396 throw std::runtime_error("Resizing const EigenArrayRef is not allowed.");
397 }
398 }
399
400 std::unique_ptr<ArrayBase> clone() const override
401 {
402 return std::make_unique<EigenArray<DecayedEigenType>>(m_data);
403 }
404
405 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const IndexFunction& mapping_fn)
406 const override
407 {
408 return row_slice_impl(m_data, num_rows, mapping_fn);
409 }
410 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const WeightedIndexFunction& mapping_fn)
411 const override
412 {
413 return row_slice_impl(m_data, num_rows, mapping_fn);
414 }
415
416 std::string type_name() const override
417 {
418 std::string scalar_name(ScalarToEnum<typename EigenType::Scalar>::name);
419 return string_format(
420 "EigenArrayRef<const Eigen::Matrix<{}, {}, {}, {}>>",
421 scalar_name,
422 static_cast<int>(EigenType::RowsAtCompileTime),
423 static_cast<int>(EigenType::ColsAtCompileTime),
424 static_cast<int>(EigenType::Options));
425 }
426
427private:
428 EigenType& m_data;
429};
430
431
436template <
437 typename _Scalar,
438 int _Rows = Eigen::Dynamic,
439 int _Cols = Eigen::Dynamic,
440 int _Options = Eigen::RowMajor,
441 bool IsConst = std::is_const<_Scalar>::value>
443
444template <typename _Scalar, int _Rows, int _Cols, int _Options>
445class RawArray<_Scalar, _Rows, _Cols, _Options, false> : public ArrayBase
446{
447public:
449 using Scalar = _Scalar;
450 using Index = ArrayBase::Index;
451 using EigenType = Eigen::Matrix<Scalar, _Rows, _Cols, _Options>;
452 using EigenMap = Eigen::Map<EigenType>;
453 using ConstEigenMap = Eigen::Map<const EigenType>;
454 using ArrayBase::WeightedIndexFunction;
455
456public:
457 RawArray(Scalar* data, Index rows, Index cols)
458 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
459 , m_data(data, rows, cols)
460 {}
461
462 std::any get_type_info() const override { return std::make_any<ArrayTypeInfo<Self>>(); }
463
464public:
465 EigenMap& get_ref() { return m_data; }
466 ConstEigenMap get_ref() const { return m_data; }
467
468 template <typename Derived>
469 void set(const Eigen::MatrixBase<Derived>& data)
470 {
471 m_data = data;
472 }
473
474 void set(Scalar* data, Index rows, Index cols) { m_data = EigenMap(data, rows, cols); }
475
476 Index rows() const override { return m_data.rows(); }
477 Index cols() const override { return m_data.cols(); }
478 bool is_row_major() const override { return EigenType::IsRowMajor != 0; }
479 void* data() override { return static_cast<void*>(m_data.data()); }
480 const void* data() const override { return static_cast<const void*>(m_data.data()); }
481 void resize(Index r, Index c) override
482 {
483 if (r != m_data.rows() || c != m_data.cols()) {
484 throw std::runtime_error("Resizing RawArray is not allowed.");
485 }
486 }
487
488 std::unique_ptr<ArrayBase> clone() const override
489 {
490 return std::make_unique<EigenArray<EigenType>>(m_data);
491 }
492
493 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const IndexFunction& mapping_fn)
494 const override
495 {
496 return row_slice_impl(m_data, num_rows, mapping_fn);
497 }
498 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const WeightedIndexFunction& mapping_fn)
499 const override
500 {
501 return row_slice_impl(m_data, num_rows, mapping_fn);
502 }
503
504 std::string type_name() const override
505 {
506 std::string scalar_name(ScalarToEnum<Scalar>::name);
507 return string_format("RawArray<{}, {}, {}, {}>", scalar_name, _Rows, _Cols, _Options);
508 }
509
510private:
511 EigenMap m_data;
512};
513
514template <typename _Scalar, int _Rows, int _Cols, int _Options>
515class RawArray<_Scalar, _Rows, _Cols, _Options, true> : public ArrayBase
516{
517public:
519 using Scalar = std::decay_t<_Scalar>;
520 using Index = ArrayBase::Index;
521 using EigenType = Eigen::Matrix<Scalar, _Rows, _Cols, _Options>;
522 using EigenMap = Eigen::Map<const EigenType>;
523 using ArrayBase::WeightedIndexFunction;
524
525public:
526 RawArray(const Scalar* data, Index rows, Index cols)
527 : ArrayBase(ScalarToEnum_v<typename EigenType::Scalar>)
528 , m_data(data, rows, cols)
529 {}
530
531 std::any get_type_info() const override { return std::make_any<ArrayTypeInfo<Self>>(); }
532
533public:
534 const EigenMap& get_ref() const { return m_data; }
535
536 Index rows() const override { return m_data.rows(); }
537 Index cols() const override { return m_data.cols(); }
538 bool is_row_major() const override { return EigenType::IsRowMajor != 0; }
539 const void* data() const override { return static_cast<const void*>(m_data.data()); }
540 void* data() override { throw std::runtime_error("This method is not supported"); }
541
542 void resize(Index r, Index c) override
543 {
544 if (r != m_data.rows() || c != m_data.cols()) {
545 throw std::runtime_error("Resizing RawArray is not allowed.");
546 }
547 }
548
549 std::unique_ptr<ArrayBase> clone() const override
550 {
551 return std::make_unique<EigenArray<EigenType>>(m_data);
552 }
553
554 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const IndexFunction& mapping_fn)
555 const override
556 {
557 return row_slice_impl(m_data, num_rows, mapping_fn);
558 }
559 std::unique_ptr<ArrayBase> row_slice(Index num_rows, const WeightedIndexFunction& mapping_fn)
560 const override
561 {
562 return row_slice_impl(m_data, num_rows, mapping_fn);
563 }
564
565 std::string type_name() const override
566 {
567 std::string scalar_name(ScalarToEnum<Scalar>::name);
568 return string_format("RawArray<const {}, {}, {}, {}>", scalar_name, _Rows, _Cols, _Options);
569 }
570
571private:
572 const EigenMap m_data;
573};
574
575} // namespace experimental
576} // namespace lagrange
577
578#include "Array.impl.h"
579#include "Array.serialization.h"
Definition: Array.h:35
virtual std::unique_ptr< ArrayBase > row_slice(Index num_rows, const IndexFunction &mapping_fn) const =0
Using index function for row mapping.
virtual std::unique_ptr< ArrayBase > row_slice(Index num_rows, const WeightedIndexFunction &mapping_fn) const =0
This is the most generic version of row_slice method.
This class is a thin wrapper around an Eigen matrix.
Definition: Array.h:200
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const IndexFunction &mapping_fn) const override
Using index function for row mapping.
Definition: Array.h:254
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const WeightedIndexFunction &mapping_fn) const override
This is the most generic version of row_slice method.
Definition: Array.h:259
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const IndexFunction &mapping_fn) const override
Using index function for row mapping.
Definition: Array.h:334
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const WeightedIndexFunction &mapping_fn) const override
This is the most generic version of row_slice method.
Definition: Array.h:339
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const IndexFunction &mapping_fn) const override
Using index function for row mapping.
Definition: Array.h:405
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const WeightedIndexFunction &mapping_fn) const override
This is the most generic version of row_slice method.
Definition: Array.h:410
This class is a thin wrapper around an Eigen matrix.
Definition: Array.h:286
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const IndexFunction &mapping_fn) const override
Using index function for row mapping.
Definition: Array.h:493
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const WeightedIndexFunction &mapping_fn) const override
This is the most generic version of row_slice method.
Definition: Array.h:498
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const IndexFunction &mapping_fn) const override
Using index function for row mapping.
Definition: Array.h:554
std::unique_ptr< ArrayBase > row_slice(Index num_rows, const WeightedIndexFunction &mapping_fn) const override
This is the most generic version of row_slice method.
Definition: Array.h:559
This class provides a thin wrapper around a raw array.
Definition: Array.h:442
#define la_runtime_assert(...)
Runtime assertion check.
Definition: assert.h:169
std::string string_format(fmt::format_string< Args... > format, Args &&... args)
Format args according to the format string fmt, and return the result as a string.
Definition: strings.h:103
Main namespace for Lagrange.
Definition: AABBIGL.h:30