Lagrange
Loading...
Searching...
No Matches
bind_attribute.h
1/*
2 * Copyright 2022 Adobe. All rights reserved.
3 * This file is licensed to you under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License. You may obtain a copy
5 * of the License at http://www.apache.org/licenses/LICENSE-2.0
6 *
7 * Unless required by applicable law or agreed to in writing, software distributed under
8 * the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
9 * OF ANY KIND, either express or implied. See the License for the specific language
10 * governing permissions and limitations under the License.
11 */
12#pragma once
13
#include "PyAttribute.h"

#include <lagrange/Attribute.h>
#include <lagrange/AttributeValueType.h>
#include <lagrange/Logger.h>
#include <lagrange/internal/string_from_scalar.h>
#include <lagrange/python/binding.h>
#include <lagrange/python/tensor_utils.h>
#include <lagrange/utils/Error.h>
#include <lagrange/utils/assert.h>
#include <lagrange/utils/fmt/format.h>
#include <lagrange/utils/invalid.h>

#include <cmath>
#include <limits>
#include <type_traits>
26
27namespace lagrange::python {
28
29void bind_attribute(nanobind::module_& m)
30{
31 namespace nb = nanobind;
32 using namespace nb::literals;
33
34 auto attr_class = nb::class_<PyAttribute>(
35 m,
36 "Attribute",
37 "Attribute data associated with mesh elements (vertices, facets, corners, edges).");
38 attr_class.def_prop_ro(
39 "element_type",
40 [](PyAttribute& self) { return self->get_element_type(); },
41 "Element type (Vertex, Facet, Corner, Edge, Value).");
42 attr_class.def_prop_ro(
43 "usage",
44 [](PyAttribute& self) { return self->get_usage(); },
45 "Usage type (Position, Normal, UV, Color, etc.).");
46 attr_class.def_prop_ro(
47 "num_channels",
48 [](PyAttribute& self) { return self->get_num_channels(); },
49 "Number of channels per element.");
50
51 attr_class.def_prop_rw(
52 "default_value",
53 [](PyAttribute& self) {
54 return self.process([](auto& attr) { return nb::cast(attr.get_default_value()); });
55 },
56 [](PyAttribute& self, double val) {
57 self.process([&](auto& attr) {
58 using ValueType = typename std::decay_t<decltype(attr)>::ValueType;
59 attr.set_default_value(static_cast<ValueType>(val));
60 });
61 },
62 "Default value for new elements.");
63 attr_class.def_prop_rw(
64 "growth_policy",
65 [](PyAttribute& self) {
66 return self.process([](auto& attr) { return attr.get_growth_policy(); });
67 },
68 [](PyAttribute& self, AttributeGrowthPolicy policy) {
69 self.process([&](auto& attr) { attr.set_growth_policy(policy); });
70 },
71 "Policy for growing the attribute when elements are added.");
72 attr_class.def_prop_rw(
73 "shrink_policy",
74 [](PyAttribute& self) {
75 return self.process([](auto& attr) { return attr.get_shrink_policy(); });
76 },
77 [](PyAttribute& self, AttributeShrinkPolicy policy) {
78 self.process([&](auto& attr) { attr.set_shrink_policy(policy); });
79 },
80 "Policy for shrinking the attribute when elements are removed.");
81 attr_class.def_prop_rw(
82 "write_policy",
83 [](PyAttribute& self) {
84 return self.process([](auto& attr) { return attr.get_write_policy(); });
85 },
86 [](PyAttribute& self, AttributeWritePolicy policy) {
87 self.process([&](auto& attr) { attr.set_write_policy(policy); });
88 },
89 "Policy for write operations on the attribute.");
90 attr_class.def_prop_rw(
91 "copy_policy",
92 [](PyAttribute& self) {
93 return self.process([](auto& attr) { return attr.get_copy_policy(); });
94 },
95 [](PyAttribute& self, AttributeCopyPolicy policy) {
96 self.process([&](auto& attr) { attr.set_copy_policy(policy); });
97 },
98 "Policy for copying the attribute.");
99 attr_class.def_prop_rw(
100 "cast_policy",
101 [](PyAttribute& self) {
102 return self.process([](auto& attr) { return attr.get_cast_policy(); });
103 },
104 [](PyAttribute& self, AttributeCastPolicy policy) {
105 self.process([&](auto& attr) { attr.set_cast_policy(policy); });
106 },
107 "Policy for casting the attribute to different types.");
108 attr_class.def(
109 "create_internal_copy",
110 [](PyAttribute& self) { self.process([](auto& attr) { attr.create_internal_copy(); }); },
111 "Create an internal copy if the attribute wraps external data.");
112 attr_class.def(
113 "clear",
114 [](PyAttribute& self) { self.process([](auto& attr) { attr.clear(); }); },
115 "Remove all elements from the attribute.");
116 attr_class.def(
117 "reserve_entries",
118 [](PyAttribute& self, size_t s) {
119 self.process([=](auto& attr) { attr.reserve_entries(s); });
120 },
121 "num_entries"_a,
122 R"(Reserve enough memory for `num_entries` entries.
123
124:param num_entries: Number of entries to reserve. It does not need to be a multiple of `num_channels`.)");
125 attr_class.def(
126 "insert_elements",
127 [](PyAttribute& self, size_t num_elements) {
128 self.process([=](auto& attr) { attr.insert_elements(num_elements); });
129 },
130 "num_elements"_a,
131 R"(Insert new elements with default value to the attribute.
132
133:param num_elements: Number of elements to insert.)");
134 attr_class.def(
135 "insert_elements",
136 [](PyAttribute& self, nb::object value) {
137 auto insert_elements = [&](auto& attr) {
138 using ValueType = typename std::decay_t<decltype(attr)>::ValueType;
139 GenericTensor tensor;
140 std::vector<ValueType> buffer;
141 if (nb::try_cast(value, tensor)) {
142 if (tensor.dtype() != nb::dtype<ValueType>()) {
143 throw nb::type_error(
144 lagrange::format(
145 "Tensor has a unexpected dtype. Expecting {}.",
147 .c_str());
148 }
149 Tensor<ValueType> local_tensor(tensor.handle());
150 auto [data, shape, stride] = tensor_to_span(local_tensor);
151 la_runtime_assert(is_dense(shape, stride));
152 attr.insert_elements(data);
153 } else if (nb::try_cast(value, buffer)) {
154 attr.insert_elements({buffer.data(), buffer.size()});
155 } else {
156 throw nb::type_error("Argument `value` must be either list or np.ndarray.");
157 }
158 };
159 self.process(insert_elements);
160 },
161 "tensor"_a,
162 R"(Insert new elements to the attribute.
163
164:param tensor: A tensor with shape (num_elements, num_channels) or (num_elements,).)");
165 attr_class.def(
166 "empty",
167 [](PyAttribute& self) { return self.process([](auto& attr) { return attr.empty(); }); },
168 "Check if the attribute has no elements.");
169 attr_class.def_prop_ro(
170 "num_elements",
171 [](PyAttribute& self) {
172 return self.process([](auto& attr) { return attr.get_num_elements(); });
173 },
174 "Number of elements in the attribute.");
175 attr_class.def_prop_ro(
176 "external",
177 [](PyAttribute& self) {
178 return self.process([](auto& attr) { return attr.is_external(); });
179 },
180 "Check if the attribute wraps external data.");
181 attr_class.def_prop_ro(
182 "readonly",
183 [](PyAttribute& self) {
184 return self.process([](auto& attr) { return attr.is_read_only(); });
185 },
186 "Check if the attribute is read-only.");
187 attr_class.def_prop_rw(
188 "data",
189 [](PyAttribute& self) {
190 return self.process([&](auto& attr) {
191 auto tensor = attribute_to_tensor(attr, nb::find(&self));
192 return nb::cast(tensor, nb::rv_policy::reference_internal);
193 });
194 },
195 [](PyAttribute& self, nb::object value) {
196 auto wrap_tensor = [&](auto& attr) {
197 using ValueType = typename std::decay_t<decltype(attr)>::ValueType;
198 GenericTensor tensor;
199 std::vector<ValueType> buffer;
200 if (nb::try_cast(value, tensor)) {
201 if (tensor.dtype() != nb::dtype<ValueType>()) {
202 throw nb::type_error(
203 lagrange::format(
204 "Tensor has a unexpected dtype. Expecting {}.",
206 .c_str());
207 }
208 Tensor<ValueType> local_tensor(tensor.handle());
209 auto [data, shape, stride] = tensor_to_span(local_tensor);
210 la_runtime_assert(is_dense(shape, stride));
211 la_runtime_assert(shape.size() == 1 ? 1 == attr.get_num_channels() : true);
213 shape.size() == 2 ? shape[1] == attr.get_num_channels() : true);
214 const size_t num_elements = shape[0];
215
216 auto owner = std::make_shared<nb::object>(nb::find(tensor));
217 attr.wrap(make_shared_span(owner, data.data(), data.size()), num_elements);
218 } else if (nb::try_cast(value, buffer)) {
219 attr.clear();
220 attr.insert_elements({buffer.data(), buffer.size()});
221 } else {
222 throw nb::type_error("Attribute.data must be either list or np.ndarray.");
223 }
224 };
225 self.process(wrap_tensor);
226 },
227 nb::for_getter(nb::sig("def data(self, /) -> numpy.typing.NDArray")),
228 "Raw data as a numpy array.");
229 attr_class.def_prop_ro(
230 "dtype",
231 [](PyAttribute& self) -> std::optional<nb::type_object> {
232 auto np = nb::module_::import_("numpy");
233 switch (self.ptr()->get_value_type()) {
234 case AttributeValueType::e_int8_t: return np.attr("int8");
235 case AttributeValueType::e_int16_t: return np.attr("int16");
236 case AttributeValueType::e_int32_t: return np.attr("int32");
237 case AttributeValueType::e_int64_t: return np.attr("int64");
238 case AttributeValueType::e_uint8_t: return np.attr("uint8");
239 case AttributeValueType::e_uint16_t: return np.attr("uint16");
240 case AttributeValueType::e_uint32_t: return np.attr("uint32");
241 case AttributeValueType::e_uint64_t: return np.attr("uint64");
242 case AttributeValueType::e_float: return np.attr("float32");
243 case AttributeValueType::e_double: return np.attr("float64");
244 default: logger().warn("Attribute has an unknown dtype."); return std::nullopt;
245 }
246 },
247 "NumPy dtype of the attribute values.");
248}
249
250} // namespace lagrange::python
Definition PyAttribute.h:24
LA_CORE_API spdlog::logger & logger()
Retrieves the current logger.
Definition Logger.cpp:40
AttributeShrinkPolicy
Policy for shrinking external attribute buffers.
Definition AttributeFwd.h:117
AttributeGrowthPolicy
Policy for growing external attribute buffers.
Definition AttributeFwd.h:96
AttributeCopyPolicy
Policy for copying attribute that are views onto external buffers.
Definition AttributeFwd.h:161
AttributeWritePolicy
Policy for attempting to write to read-only external buffers.
Definition AttributeFwd.h:138
AttributeCastPolicy
Policy for remapping invalid values when casting to a different value type.
Definition AttributeFwd.h:182
#define la_runtime_assert(...)
Runtime assertion check.
Definition assert.h:175
SharedSpan< T > make_shared_span(const std::shared_ptr< Y > &r, T *element_ptr, size_t size)
Created a SharedSpan object around an internal buffer of a parent object.
Definition SharedSpan.h:101
std::string_view string_from_scalar()
Returns a human-readable string from any supported attribute value type.