#include <lagrange/image/Array3D.h>
#include <lagrange/image/View3D.h>
#include <lagrange/python/binding.h>
#include <lagrange/python/tensor_utils.h>
#include <lagrange/utils/assert.h>

#include <array>
#include <cstddef>
20namespace lagrange::python {
22namespace nb = nanobind;
24using ImageShape = nb::shape<-1, -1, -1>;
26template <
typename Scalar>
27using ImageTensor = nb::ndarray<Scalar, ImageShape, nb::numpy, nb::c_contig, nb::device::cpu>;
32template <
typename Scalar>
33auto tensor_to_image_view(
const ImageTensor<Scalar>& tensor) -> image::experimental::View3D<Scalar>
35 const image::experimental::dextents<size_t, 3> shape{
40 const std::array<size_t, 3> strides{
41 static_cast<size_t>(tensor.stride(1)),
42 static_cast<size_t>(tensor.stride(0)),
43 static_cast<size_t>(tensor.stride(2)),
45 const image::experimental::layout_stride::mapping mapping{shape, strides};
46 image::experimental::View3D<Scalar> view{
47 static_cast<float*
>(tensor.data()),
53template <
typename Scalar>
54void copy_tensor_to_image_view(
55 const ImageTensor<Scalar>& tensor,
56 image::experimental::View3D<Scalar> image)
58 const auto width =
static_cast<unsigned int>(tensor.shape(1));
59 const auto height =
static_cast<unsigned int>(tensor.shape(0));
60 const auto num_channels =
static_cast<unsigned int>(tensor.shape(2));
62 image.extent(0) == width && image.extent(1) == height && image.extent(2) == num_channels,
63 "Tensor and mdspan dimensions do not match");
64 for (
unsigned int j = 0; j < height; j++) {
65 for (
unsigned int i = 0; i < width; i++) {
66 for (
unsigned int c = 0; c < num_channels; c++) {
67 image(i, j, c) = tensor(j, i, c);
73template <
typename Scalar>
74nb::object image_array_to_tensor(
const image::experimental::Array3D<Scalar>& image_)
76 auto image =
const_cast<image::experimental::Array3D<Scalar>&
>(image_);
77 auto tensor = Tensor<float>(
78 static_cast<float*
>(image.data()),
86 static_cast<int64_t>(image.stride(1)),
87 static_cast<int64_t>(image.stride(0)),
88 static_cast<int64_t>(image.stride(2)),
#define la_runtime_assert(...)
Runtime assertion check.
Definition assert.h:174