Lagrange
All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Modules Pages
bind_attribute.h
1/*
2 * Copyright 2022 Adobe. All rights reserved.
3 * This file is licensed to you under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License. You may obtain a copy
5 * of the License at http://www.apache.org/licenses/LICENSE-2.0
6 *
7 * Unless required by applicable law or agreed to in writing, software distributed under
8 * the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
9 * OF ANY KIND, either express or implied. See the License for the specific language
10 * governing permissions and limitations under the License.
11 */
12#pragma once
13
14#include "PyAttribute.h"
15
16#include <lagrange/Attribute.h>
17#include <lagrange/AttributeValueType.h>
18#include <lagrange/Logger.h>
19#include <lagrange/internal/string_from_scalar.h>
20#include <lagrange/python/tensor_utils.h>
21#include <lagrange/utils/Error.h>
22#include <lagrange/utils/assert.h>
23#include <lagrange/utils/invalid.h>
24
25// clang-format off
26#include <lagrange/utils/warnoff.h>
27#include <nanobind/nanobind.h>
28#include <nanobind/stl/optional.h>
29#include <lagrange/utils/warnon.h>
30// clang-format on
31
32namespace lagrange::python {
33
34void bind_attribute(nanobind::module_& m)
35{
36 namespace nb = nanobind;
37 using namespace nb::literals;
38
39 auto attr_class = nb::class_<PyAttribute>(m, "Attribute", "Mesh attribute");
40 attr_class.def_prop_ro(
41 "element_type",
42 [](PyAttribute& self) { return self->get_element_type(); },
43 "Element type of the attribute.");
44 attr_class.def_prop_ro(
45 "usage",
46 [](PyAttribute& self) { return self->get_usage(); },
47 "Usage of the attribute.");
48 attr_class.def_prop_ro(
49 "num_channels",
50 [](PyAttribute& self) { return self->get_num_channels(); },
51 "Number of channels in the attribute.");
52
53 attr_class.def_prop_rw(
54 "default_value",
55 [](PyAttribute& self) {
56 return self.process([](auto& attr) { return nb::cast(attr.get_default_value()); });
57 },
58 [](PyAttribute& self, double val) {
59 self.process([&](auto& attr) {
60 using ValueType = typename std::decay_t<decltype(attr)>::ValueType;
61 attr.set_default_value(static_cast<ValueType>(val));
62 });
63 },
64 "Default value of the attribute.");
65 attr_class.def_prop_rw(
66 "growth_policy",
67 [](PyAttribute& self) {
68 return self.process([](auto& attr) { return attr.get_growth_policy(); });
69 },
70 [](PyAttribute& self, AttributeGrowthPolicy policy) {
71 self.process([&](auto& attr) { attr.set_growth_policy(policy); });
72 },
73 "Growth policy of the attribute.");
74 attr_class.def_prop_rw(
75 "shrink_policy",
76 [](PyAttribute& self) {
77 return self.process([](auto& attr) { return attr.get_shrink_policy(); });
78 },
79 [](PyAttribute& self, AttributeShrinkPolicy policy) {
80 self.process([&](auto& attr) { attr.set_shrink_policy(policy); });
81 },
82 "Shrink policy of the attribute.");
83 attr_class.def_prop_rw(
84 "write_policy",
85 [](PyAttribute& self) {
86 return self.process([](auto& attr) { return attr.get_write_policy(); });
87 },
88 [](PyAttribute& self, AttributeWritePolicy policy) {
89 self.process([&](auto& attr) { attr.set_write_policy(policy); });
90 },
91 "Write policy of the attribute.");
92 attr_class.def_prop_rw(
93 "copy_policy",
94 [](PyAttribute& self) {
95 return self.process([](auto& attr) { return attr.get_copy_policy(); });
96 },
97 [](PyAttribute& self, AttributeCopyPolicy policy) {
98 self.process([&](auto& attr) { attr.set_copy_policy(policy); });
99 },
100 "Copy policy of the attribute.");
101 attr_class.def_prop_rw(
102 "cast_policy",
103 [](PyAttribute& self) {
104 return self.process([](auto& attr) { return attr.get_cast_policy(); });
105 },
106 [](PyAttribute& self, AttributeCastPolicy policy) {
107 self.process([&](auto& attr) { attr.set_cast_policy(policy); });
108 },
109 "Copy policy of the attribute.");
110 attr_class.def(
111 "create_internal_copy",
112 [](PyAttribute& self) { self.process([](auto& attr) { attr.create_internal_copy(); }); },
113 "Create an internal copy of the attribute.");
114 attr_class.def(
115 "clear",
116 [](PyAttribute& self) { self.process([](auto& attr) { attr.clear(); }); },
117 "Clear the attribute so it has no elements.");
118 attr_class.def(
119 "reserve_entries",
120 [](PyAttribute& self, size_t s) {
121 self.process([=](auto& attr) { attr.reserve_entries(s); });
122 },
123 "num_entries"_a,
124 R"(Reserve enough memory for `num_entries` entries.
125
126:param num_entries: Number of entries to reserve. It does not need to be a multiple of `num_channels`.)");
127 attr_class.def(
128 "insert_elements",
129 [](PyAttribute& self, size_t num_elements) {
130 self.process([=](auto& attr) { attr.insert_elements(num_elements); });
131 },
132 "num_elements"_a,
133 R"(Insert new elements with default value to the attribute.
134
135:param num_elements: Number of elements to insert.)");
136 attr_class.def(
137 "insert_elements",
138 [](PyAttribute& self, nb::object value) {
139 auto insert_elements = [&](auto& attr) {
140 using ValueType = typename std::decay_t<decltype(attr)>::ValueType;
141 GenericTensor tensor;
142 std::vector<ValueType> buffer;
143 if (nb::try_cast(value, tensor)) {
144 if (tensor.dtype() != nb::dtype<ValueType>()) {
145 throw nb::type_error(fmt::format(
146 "Tensor has a unexpected dtype. Expecting {}.",
147 internal::string_from_scalar<ValueType>())
148 .c_str());
149 }
150 Tensor<ValueType> local_tensor(tensor.handle());
151 auto [data, shape, stride] = tensor_to_span(local_tensor);
152 la_runtime_assert(is_dense(shape, stride));
153 attr.insert_elements(data);
154 } else if (nb::try_cast(value, buffer)) {
155 attr.insert_elements({buffer.data(), buffer.size()});
156 } else {
157 throw nb::type_error("Argument `value` must be either list or np.ndarray.");
158 }
159 };
160 self.process(insert_elements);
161 },
162 "tensor"_a,
163 R"(Insert new elements to the attribute.
164
165:param tensor: A tensor with shape (num_elements, num_channels) or (num_elements,).)");
166 attr_class.def(
167 "empty",
168 [](PyAttribute& self) { return self.process([](auto& attr) { return attr.empty(); }); },
169 "Return true if the attribute is empty.");
170 attr_class.def_prop_ro(
171 "num_elements",
172 [](PyAttribute& self) {
173 return self.process([](auto& attr) { return attr.get_num_elements(); });
174 },
175 "Number of elements in the attribute.");
176 attr_class.def_prop_ro(
177 "external",
178 [](PyAttribute& self) {
179 return self.process([](auto& attr) { return attr.is_external(); });
180 },
181 "Return true if the attribute wraps external buffer.");
182 attr_class.def_prop_ro(
183 "readonly",
184 [](PyAttribute& self) {
185 return self.process([](auto& attr) { return attr.is_read_only(); });
186 },
187 "Return true if the attribute is read-only.");
188 attr_class.def_prop_rw(
189 "data",
190 [](PyAttribute& self) {
191 return self.process([&](auto& attr) {
192 auto tensor = attribute_to_tensor(attr, nb::find(&self));
193 return nb::cast(tensor, nb::rv_policy::reference_internal);
194 });
195 },
196 [](PyAttribute& self, nb::object value) {
197 auto wrap_tensor = [&](auto& attr) {
198 using ValueType = typename std::decay_t<decltype(attr)>::ValueType;
199 GenericTensor tensor;
200 std::vector<ValueType> buffer;
201 if (nb::try_cast(value, tensor)) {
202 if (tensor.dtype() != nb::dtype<ValueType>()) {
203 throw nb::type_error(fmt::format(
204 "Tensor has a unexpected dtype. Expecting {}.",
205 internal::string_from_scalar<ValueType>())
206 .c_str());
207 }
208 Tensor<ValueType> local_tensor(tensor.handle());
209 auto [data, shape, stride] = tensor_to_span(local_tensor);
210 la_runtime_assert(is_dense(shape, stride));
211 la_runtime_assert(shape.size() == 1 ? 1 == attr.get_num_channels() : true);
213 shape.size() == 2 ? shape[1] == attr.get_num_channels() : true);
214 const size_t num_elements = shape[0];
215
216 auto owner = std::make_shared<nb::object>(nb::find(tensor));
217 attr.wrap(make_shared_span(owner, data.data(), data.size()), num_elements);
218 } else if (nb::try_cast(value, buffer)) {
219 attr.clear();
220 attr.insert_elements({buffer.data(), buffer.size()});
221 } else {
222 throw nb::type_error("Attribute.data must be either list or np.ndarray.");
223 }
224 };
225 self.process(wrap_tensor);
226 },
227 nb::for_getter(nb::sig("def data(self, /) -> numpy.typing.NDArray")),
228 "Raw data buffer of the attribute.");
229 attr_class.def_prop_ro(
230 "dtype",
231 [](PyAttribute& self) -> std::optional<nb::type_object> {
232 auto np = nb::module_::import_("numpy");
233 switch (self.ptr()->get_value_type()) {
234 case AttributeValueType::e_int8_t: return np.attr("int8");
235 case AttributeValueType::e_int16_t: return np.attr("int16");
236 case AttributeValueType::e_int32_t: return np.attr("int32");
237 case AttributeValueType::e_int64_t: return np.attr("int64");
238 case AttributeValueType::e_uint8_t: return np.attr("uint8");
239 case AttributeValueType::e_uint16_t: return np.attr("uint16");
240 case AttributeValueType::e_uint32_t: return np.attr("uint32");
241 case AttributeValueType::e_uint64_t: return np.attr("uint64");
242 case AttributeValueType::e_float: return np.attr("float32");
243 case AttributeValueType::e_double: return np.attr("float64");
244 default: logger().warn("Attribute has an unknown dtype."); return std::nullopt;
245 }
246 },
247 "Value type of the attribute.");
248}
249
250} // namespace lagrange::python
LA_CORE_API spdlog::logger & logger()
Retrieves the current logger.
Definition: Logger.cpp:40
AttributeShrinkPolicy
Policy for shrinking external attribute buffers.
Definition: AttributeFwd.h:117
AttributeGrowthPolicy
Policy for growing external attribute buffers.
Definition: AttributeFwd.h:96
AttributeCopyPolicy
Policy for copying attributes that are views onto external buffers.
Definition: AttributeFwd.h:161
AttributeWritePolicy
Policy for attempting to write to read-only external buffers.
Definition: AttributeFwd.h:138
AttributeCastPolicy
Policy for remapping invalid values when casting to a different value type.
Definition: AttributeFwd.h:182
SurfaceMesh< ToScalar, ToIndex > cast(const SurfaceMesh< FromScalar, FromIndex > &source_mesh, const AttributeFilter &convertible_attributes={}, std::vector< std::string > *converted_attributes_names=nullptr)
Cast a mesh to a mesh of different scalar and/or index type.
#define la_runtime_assert(...)
Runtime assertion check.
Definition: assert.h:169
SharedSpan< T > make_shared_span(const std::shared_ptr< Y > &r, T *element_ptr, size_t size)
Creates a SharedSpan object around an internal buffer of a parent object.
Definition: SharedSpan.h:101