Lagrange
Loading...
Searching...
No Matches
bind_attribute.h
/*
 * Copyright 2022 Adobe. All rights reserved.
 * This file is licensed to you under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License. You may obtain a copy
 * of the License at http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under
 * the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
 * OF ANY KIND, either express or implied. See the License for the specific language
 * governing permissions and limitations under the License.
 */
12#pragma once
13
14#include "PyAttribute.h"
15
16#include <lagrange/Attribute.h>
17#include <lagrange/AttributeValueType.h>
18#include <lagrange/Logger.h>
19#include <lagrange/internal/string_from_scalar.h>
20#include <lagrange/python/binding.h>
21#include <lagrange/python/tensor_utils.h>
22#include <lagrange/utils/Error.h>
23#include <lagrange/utils/assert.h>
24#include <lagrange/utils/invalid.h>
25
26namespace lagrange::python {
27
28void bind_attribute(nanobind::module_& m)
29{
30 namespace nb = nanobind;
31 using namespace nb::literals;
32
33 auto attr_class = nb::class_<PyAttribute>(
34 m,
35 "Attribute",
36 "Attribute data associated with mesh elements (vertices, facets, corners, edges).");
37 attr_class.def_prop_ro(
38 "element_type",
39 [](PyAttribute& self) { return self->get_element_type(); },
40 "Element type (Vertex, Facet, Corner, Edge, Value).");
41 attr_class.def_prop_ro(
42 "usage",
43 [](PyAttribute& self) { return self->get_usage(); },
44 "Usage type (Position, Normal, UV, Color, etc.).");
45 attr_class.def_prop_ro(
46 "num_channels",
47 [](PyAttribute& self) { return self->get_num_channels(); },
48 "Number of channels per element.");
49
50 attr_class.def_prop_rw(
51 "default_value",
52 [](PyAttribute& self) {
53 return self.process([](auto& attr) { return nb::cast(attr.get_default_value()); });
54 },
55 [](PyAttribute& self, double val) {
56 self.process([&](auto& attr) {
57 using ValueType = typename std::decay_t<decltype(attr)>::ValueType;
58 attr.set_default_value(static_cast<ValueType>(val));
59 });
60 },
61 "Default value for new elements.");
62 attr_class.def_prop_rw(
63 "growth_policy",
64 [](PyAttribute& self) {
65 return self.process([](auto& attr) { return attr.get_growth_policy(); });
66 },
67 [](PyAttribute& self, AttributeGrowthPolicy policy) {
68 self.process([&](auto& attr) { attr.set_growth_policy(policy); });
69 },
70 "Policy for growing the attribute when elements are added.");
71 attr_class.def_prop_rw(
72 "shrink_policy",
73 [](PyAttribute& self) {
74 return self.process([](auto& attr) { return attr.get_shrink_policy(); });
75 },
76 [](PyAttribute& self, AttributeShrinkPolicy policy) {
77 self.process([&](auto& attr) { attr.set_shrink_policy(policy); });
78 },
79 "Policy for shrinking the attribute when elements are removed.");
80 attr_class.def_prop_rw(
81 "write_policy",
82 [](PyAttribute& self) {
83 return self.process([](auto& attr) { return attr.get_write_policy(); });
84 },
85 [](PyAttribute& self, AttributeWritePolicy policy) {
86 self.process([&](auto& attr) { attr.set_write_policy(policy); });
87 },
88 "Policy for write operations on the attribute.");
89 attr_class.def_prop_rw(
90 "copy_policy",
91 [](PyAttribute& self) {
92 return self.process([](auto& attr) { return attr.get_copy_policy(); });
93 },
94 [](PyAttribute& self, AttributeCopyPolicy policy) {
95 self.process([&](auto& attr) { attr.set_copy_policy(policy); });
96 },
97 "Policy for copying the attribute.");
98 attr_class.def_prop_rw(
99 "cast_policy",
100 [](PyAttribute& self) {
101 return self.process([](auto& attr) { return attr.get_cast_policy(); });
102 },
103 [](PyAttribute& self, AttributeCastPolicy policy) {
104 self.process([&](auto& attr) { attr.set_cast_policy(policy); });
105 },
106 "Policy for casting the attribute to different types.");
107 attr_class.def(
108 "create_internal_copy",
109 [](PyAttribute& self) { self.process([](auto& attr) { attr.create_internal_copy(); }); },
110 "Create an internal copy if the attribute wraps external data.");
111 attr_class.def(
112 "clear",
113 [](PyAttribute& self) { self.process([](auto& attr) { attr.clear(); }); },
114 "Remove all elements from the attribute.");
115 attr_class.def(
116 "reserve_entries",
117 [](PyAttribute& self, size_t s) {
118 self.process([=](auto& attr) { attr.reserve_entries(s); });
119 },
120 "num_entries"_a,
121 R"(Reserve enough memory for `num_entries` entries.
122
123:param num_entries: Number of entries to reserve. It does not need to be a multiple of `num_channels`.)");
124 attr_class.def(
125 "insert_elements",
126 [](PyAttribute& self, size_t num_elements) {
127 self.process([=](auto& attr) { attr.insert_elements(num_elements); });
128 },
129 "num_elements"_a,
130 R"(Insert new elements with default value to the attribute.
131
132:param num_elements: Number of elements to insert.)");
133 attr_class.def(
134 "insert_elements",
135 [](PyAttribute& self, nb::object value) {
136 auto insert_elements = [&](auto& attr) {
137 using ValueType = typename std::decay_t<decltype(attr)>::ValueType;
138 GenericTensor tensor;
139 std::vector<ValueType> buffer;
140 if (nb::try_cast(value, tensor)) {
141 if (tensor.dtype() != nb::dtype<ValueType>()) {
142 throw nb::type_error(
143 fmt::format(
144 "Tensor has a unexpected dtype. Expecting {}.",
146 .c_str());
147 }
148 Tensor<ValueType> local_tensor(tensor.handle());
149 auto [data, shape, stride] = tensor_to_span(local_tensor);
150 la_runtime_assert(is_dense(shape, stride));
151 attr.insert_elements(data);
152 } else if (nb::try_cast(value, buffer)) {
153 attr.insert_elements({buffer.data(), buffer.size()});
154 } else {
155 throw nb::type_error("Argument `value` must be either list or np.ndarray.");
156 }
157 };
158 self.process(insert_elements);
159 },
160 "tensor"_a,
161 R"(Insert new elements to the attribute.
162
163:param tensor: A tensor with shape (num_elements, num_channels) or (num_elements,).)");
164 attr_class.def(
165 "empty",
166 [](PyAttribute& self) { return self.process([](auto& attr) { return attr.empty(); }); },
167 "Check if the attribute has no elements.");
168 attr_class.def_prop_ro(
169 "num_elements",
170 [](PyAttribute& self) {
171 return self.process([](auto& attr) { return attr.get_num_elements(); });
172 },
173 "Number of elements in the attribute.");
174 attr_class.def_prop_ro(
175 "external",
176 [](PyAttribute& self) {
177 return self.process([](auto& attr) { return attr.is_external(); });
178 },
179 "Check if the attribute wraps external data.");
180 attr_class.def_prop_ro(
181 "readonly",
182 [](PyAttribute& self) {
183 return self.process([](auto& attr) { return attr.is_read_only(); });
184 },
185 "Check if the attribute is read-only.");
186 attr_class.def_prop_rw(
187 "data",
188 [](PyAttribute& self) {
189 return self.process([&](auto& attr) {
190 auto tensor = attribute_to_tensor(attr, nb::find(&self));
191 return nb::cast(tensor, nb::rv_policy::reference_internal);
192 });
193 },
194 [](PyAttribute& self, nb::object value) {
195 auto wrap_tensor = [&](auto& attr) {
196 using ValueType = typename std::decay_t<decltype(attr)>::ValueType;
197 GenericTensor tensor;
198 std::vector<ValueType> buffer;
199 if (nb::try_cast(value, tensor)) {
200 if (tensor.dtype() != nb::dtype<ValueType>()) {
201 throw nb::type_error(
202 fmt::format(
203 "Tensor has a unexpected dtype. Expecting {}.",
205 .c_str());
206 }
207 Tensor<ValueType> local_tensor(tensor.handle());
208 auto [data, shape, stride] = tensor_to_span(local_tensor);
209 la_runtime_assert(is_dense(shape, stride));
210 la_runtime_assert(shape.size() == 1 ? 1 == attr.get_num_channels() : true);
212 shape.size() == 2 ? shape[1] == attr.get_num_channels() : true);
213 const size_t num_elements = shape[0];
214
215 auto owner = std::make_shared<nb::object>(nb::find(tensor));
216 attr.wrap(make_shared_span(owner, data.data(), data.size()), num_elements);
217 } else if (nb::try_cast(value, buffer)) {
218 attr.clear();
219 attr.insert_elements({buffer.data(), buffer.size()});
220 } else {
221 throw nb::type_error("Attribute.data must be either list or np.ndarray.");
222 }
223 };
224 self.process(wrap_tensor);
225 },
226 nb::for_getter(nb::sig("def data(self, /) -> numpy.typing.NDArray")),
227 "Raw data as a numpy array.");
228 attr_class.def_prop_ro(
229 "dtype",
230 [](PyAttribute& self) -> std::optional<nb::type_object> {
231 auto np = nb::module_::import_("numpy");
232 switch (self.ptr()->get_value_type()) {
233 case AttributeValueType::e_int8_t: return np.attr("int8");
234 case AttributeValueType::e_int16_t: return np.attr("int16");
235 case AttributeValueType::e_int32_t: return np.attr("int32");
236 case AttributeValueType::e_int64_t: return np.attr("int64");
237 case AttributeValueType::e_uint8_t: return np.attr("uint8");
238 case AttributeValueType::e_uint16_t: return np.attr("uint16");
239 case AttributeValueType::e_uint32_t: return np.attr("uint32");
240 case AttributeValueType::e_uint64_t: return np.attr("uint64");
241 case AttributeValueType::e_float: return np.attr("float32");
242 case AttributeValueType::e_double: return np.attr("float64");
243 default: logger().warn("Attribute has an unknown dtype."); return std::nullopt;
244 }
245 },
246 "NumPy dtype of the attribute values.");
247}
248
249} // namespace lagrange::python
Definition PyAttribute.h:24
LA_CORE_API spdlog::logger & logger()
Retrieves the current logger.
Definition Logger.cpp:40
AttributeShrinkPolicy
Policy for shrinking external attribute buffers.
Definition AttributeFwd.h:117
AttributeGrowthPolicy
Policy for growing external attribute buffers.
Definition AttributeFwd.h:96
AttributeCopyPolicy
Policy for copying attribute that are views onto external buffers.
Definition AttributeFwd.h:161
AttributeWritePolicy
Policy for attempting to write to read-only external buffers.
Definition AttributeFwd.h:138
AttributeCastPolicy
Policy for remapping invalid values when casting to a different value type.
Definition AttributeFwd.h:182
#define la_runtime_assert(...)
Runtime assertion check.
Definition assert.h:174
SharedSpan< T > make_shared_span(const std::shared_ptr< Y > &r, T *element_ptr, size_t size)
Created a SharedSpan object around an internal buffer of a parent object.
Definition SharedSpan.h:101
std::string_view string_from_scalar()
Returns a human-readable string from any supported attribute value type.