// tensor_utils.h (Lagrange)
/*
 * Copyright 2022 Adobe. All rights reserved.
 * This file is licensed to you under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License. You may obtain a copy
 * of the License at http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under
 * the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
 * OF ANY KIND, either express or implied. See the License for the specific language
 * governing permissions and limitations under the License.
 */
12#pragma once
13
#include <lagrange/Logger.h>
#include <lagrange/utils/span.h>

#include <lagrange/Attribute.h>
#include <lagrange/IndexedAttribute.h>
#include <lagrange/python/binding.h>
#include <lagrange/utils/SmallVector.h>

#include <cstddef>
#include <cstdint>
#include <tuple>
#include <type_traits>
#include <vector>
24
25namespace lagrange::python {
26namespace nb = nanobind;
27using namespace nb::literals;
28
29template <typename ValueType>
30using Tensor = nb::ndarray<ValueType, nb::numpy, nb::c_contig, nb::device::cpu>;
31using GenericTensor = nb::ndarray<nb::c_contig, nb::numpy, nb::device::cpu>;
32using Shape = SmallVector<size_t, 2>;
33using Stride = SmallVector<int64_t, 2>;
34
42bool is_vector(const Shape& shape);
43
52bool check_shape(const Shape& shape, size_t expected_size);
53
66bool check_shape(const Shape& shape, size_t expected_rows, size_t expected_cols);
67
77bool is_dense(const Shape& shape, const Stride& stride);
78
79
85template <typename ValueType>
86Tensor<ValueType> create_empty_tensor();
87
88template <typename ValueType>
89std::tuple<span<ValueType>, Shape, Stride> tensor_to_span(Tensor<ValueType> tensor);
90
91template <typename ValueType>
92Tensor<std::decay_t<ValueType>> span_to_tensor(span<ValueType> values, nb::handle base);
93
94template <typename ValueType>
95Tensor<std::decay_t<ValueType>>
96span_to_tensor(span<ValueType> values, span<const size_t> shape, nb::handle base);
97
98template <typename ValueType>
99Tensor<std::decay_t<ValueType>> span_to_tensor(
100 span<ValueType> values,
101 span<const size_t> shape,
102 span<const int64_t> stride,
103 nb::handle base);
104
105template <typename ValueType>
106Tensor<std::decay_t<ValueType>> attribute_to_tensor(
107 const Attribute<ValueType>& attr,
108 nb::handle base);
109
110template <typename ValueType>
111Tensor<std::decay_t<ValueType>>
112attribute_to_tensor(const Attribute<ValueType>& attr, span<const size_t> shape, nb::handle base);
113
114} // namespace lagrange::python
// Note: `span` used in this header is ::nonstd::span<T, Extent>, a bounds-safe view for
// sequences of objects (defined in span.h).