Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,16 @@ message(STATUS "Found numpy: ${NUMPY_INCLUDE_DIRS}")
# =====

set(XTENSOR_PYTHON_HEADERS
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyarray.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyarray_backstrides.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pycontainer.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pystrides_adaptor.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pytensor.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyvectorize.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_python_config.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_type_caster_base.hpp
)
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyarray.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyarray_backstrides.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pycontainer.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pynative_casters.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pystrides_adaptor.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pytensor.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyvectorize.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_python_config.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_type_caster_base.hpp
)

add_library(xtensor-python INTERFACE)
target_include_directories(xtensor-python INTERFACE
Expand Down
6 changes: 1 addition & 5 deletions include/xtensor-python/pyarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "pyarray_backstrides.hpp"
#include "pycontainer.hpp"
#include "pystrides_adaptor.hpp"
#include "pynative_casters.hpp"
#include "xtensor_type_caster_base.hpp"

namespace xt
Expand Down Expand Up @@ -91,11 +92,6 @@ namespace pybind11
}
};

// Type caster for casting xarray to ndarray
template <class T, xt::layout_type L>
struct type_caster<xt::xarray<T, L>> : xtensor_type_caster_base<xt::xarray<T, L>>
{
};
}
}

Expand Down
52 changes: 52 additions & 0 deletions include/xtensor-python/pynative_casters.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/***************************************************************************
* Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay          *
* Copyright (c) QuantStack                                                 *
*                                                                          *
* Distributed under the terms of the BSD 3-Clause License.                 *
*                                                                          *
* The full license is in the file LICENSE, distributed with this software. *
****************************************************************************/

#ifndef PYNATIVE_CASTERS_HPP
#define PYNATIVE_CASTERS_HPP

#include "xtensor_type_caster_base.hpp"


namespace pybind11
{
    namespace detail
    {
        // pybind11 type_caster specializations that let the "native" xtensor
        // expression types cross the C++/Python boundary as numpy ndarrays.
        // All conversion machinery is inherited from xtensor_type_caster_base;
        // each empty specialization below only opts one xtensor type in.

        // Type caster for casting xt::xarray to ndarray
        template <class T, xt::layout_type L>
        struct type_caster<xt::xarray<T, L>> : xtensor_type_caster_base<xt::xarray<T, L>>
        {
        };

        // Type caster for casting xt::xtensor to ndarray
        template <class T, std::size_t N, xt::layout_type L>
        struct type_caster<xt::xtensor<T, N, L>> : xtensor_type_caster_base<xt::xtensor<T, N, L>>
        {
        };

        // Type caster for casting xt::xstrided_view to ndarray
        template <class CT, class S, xt::layout_type L, class FST>
        struct type_caster<xt::xstrided_view<CT, S, L, FST>> : xtensor_type_caster_base<xt::xstrided_view<CT, S, L, FST>>
        {
        };

        // Type caster for casting xt::xarray_adaptor to ndarray
        template <class EC, xt::layout_type L, class SC, class Tag>
        struct type_caster<xt::xarray_adaptor<EC, L, SC, Tag>> : xtensor_type_caster_base<xt::xarray_adaptor<EC, L, SC, Tag>>
        {
        };

        // Type caster for casting xt::xtensor_adaptor to ndarray
        template <class EC, std::size_t N, xt::layout_type L, class Tag>
        struct type_caster<xt::xtensor_adaptor<EC, N, L, Tag>> : xtensor_type_caster_base<xt::xtensor_adaptor<EC, N, L, Tag>>
        {
        };
    }
}

#endif
6 changes: 1 addition & 5 deletions include/xtensor-python/pytensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "pycontainer.hpp"
#include "pystrides_adaptor.hpp"
#include "pynative_casters.hpp"
#include "xtensor_type_caster_base.hpp"

namespace xt
Expand Down Expand Up @@ -99,11 +100,6 @@ namespace pybind11
}
};

// Type caster for casting xt::xtensor to ndarray
template <class T, std::size_t N, xt::layout_type L>
struct type_caster<xt::xtensor<T, N, L>> : xtensor_type_caster_base<xt::xtensor<T, N, L>>
{
};
}
}

Expand Down
12 changes: 6 additions & 6 deletions include/xtensor-python/xtensor_type_caster_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace pybind11
{
namespace detail
{
// Casts an xtensor (or xarray) type to numpy array.If given a base,
// Casts a strided expression type to numpy array. If given a base,
// the numpy array references the src data, otherwise it'll make a copy.
// The writeable attributes lets you specify writeable flag for the array.
template <typename Type>
Expand All @@ -39,7 +39,7 @@ namespace pybind11
std::vector<std::size_t> python_shape(src.shape().size());
std::copy(src.shape().begin(), src.shape().end(), python_shape.begin());

array a(python_shape, python_strides, src.begin(), base);
array a(python_shape, python_strides, &*(src.begin()), base);

if (!writeable)
{
Expand All @@ -49,8 +49,8 @@ namespace pybind11
return a.release();
}

// Takes an lvalue ref to some xtensor (or xarray) type and a (python) base object, creating a numpy array that
// reference the xtensor object's data with `base` as the python-registered base class (if omitted,
// Takes an lvalue ref to some strided expression type and a (python) base object, creating a numpy array that
// references the expression object's data with `base` as the python-registered base class (if omitted,
// the base will be set to None, and lifetime management is up to the caller). The numpy array is
// non-writeable if the given type is const.
template <typename Type, typename CType>
Expand All @@ -59,7 +59,7 @@ namespace pybind11
return xtensor_array_cast<Type>(src, parent, !std::is_const<CType>::value);
}

// Takes a pointer to xtensor (or xarray), builds a capsule around it, then returns a numpy
// Takes a pointer to a strided expression, builds a capsule around it, then returns a numpy
// array that references the encapsulated data with a python-side reference to the capsule to tie
// its destruction to that of any dependent python objects. Const-ness is determined by whether or
// not the CType of the pointer given is const.
Expand All @@ -70,7 +70,7 @@ namespace pybind11
return xtensor_ref_array<Type>(*src, base);
}

// Base class of type_caster for xtensor and xarray
// Base class of type_caster for strided expressions
template <class Type>
struct xtensor_type_caster_base
{
Expand Down
56 changes: 56 additions & 0 deletions test_python/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "xtensor-python/pyarray.hpp"
#include "xtensor-python/pytensor.hpp"
#include "xtensor-python/pyvectorize.hpp"
#include "xtensor/xadapt.hpp"
#include "xtensor/xstrided_view.hpp"

namespace py = pybind11;
using complex_t = std::complex<double>;
Expand Down Expand Up @@ -133,6 +135,49 @@ class C
array_type m_array;
};

// Test fixture exposing xtensor expressions (array, strided view, adaptors)
// to Python in order to exercise the native type casters.
struct test_native_casters
{
    using array_type = xt::xarray<double>;

    // Backing storage aliased by every non-owning view/adapter returned below.
    array_type a = xt::ones<double>({50, 50});

    // Reference to the member array; lifetime is tied to this instance
    // (bound with return_value_policy::reference_internal in the module).
    const auto & get_array()
    {
        return a;
    }

    // Non-owning strided view over `a` (rows [0,1), columns 0 and 2).
    auto get_strided_view()
    {
        return xt::strided_view(a, {xt::range(0, 1), xt::range(0, 3, 2)});
    }

    // Non-owning 2x2 adapter over the first elements of `a`'s buffer,
    // using a dynamic (std::vector) shape, with custom strides {3, 2}.
    auto get_array_adapter()
    {
        using shape_type = std::vector<size_t>;
        shape_type shape = {2, 2};
        shape_type stride = {3, 2};
        return xt::adapt(a.data(), 4, xt::no_ownership(), shape, stride);
    }

    // Same buffer/strides as get_array_adapter(), but with a static
    // (std::array) shape so xt::adapt yields a tensor adaptor instead.
    auto get_tensor_adapter()
    {
        using shape_type = std::array<size_t, 2>;
        shape_type shape = {2, 2};
        shape_type stride = {3, 2};
        return xt::adapt(a.data(), 4, xt::no_ownership(), shape, stride);
    }

    // Adapter that takes ownership of a freshly allocated buffer of ones;
    // the adapter is expected to release the memory itself via
    // xt::acquire_ownership (presumably delete[] — confirm against xt::adapt docs).
    auto get_owning_array_adapter()
    {
        size_t size = 100;
        int * data = new int[size];
        std::fill(data, data + size, 1);

        using shape_type = std::vector<size_t>;
        shape_type shape = {size};
        return xt::adapt(std::move(data), size, xt::acquire_ownership(), shape);
    }
};

xt::pyarray<A> dtype_to_python()
{
A a1{123, 321, 'a', {1, 2, 3}};
Expand Down Expand Up @@ -257,4 +302,15 @@ PYBIND11_MODULE(xtensor_python_test, m)

m.def("diff_shape_overload", [](xt::pytensor<int, 1> a) { return 1; });
m.def("diff_shape_overload", [](xt::pytensor<int, 2> a) { return 2; });

py::class_<test_native_casters>(m, "test_native_casters")
.def(py::init<>())
.def("get_array", &test_native_casters::get_array, py::return_value_policy::reference_internal) // memory managed by the class instance
.def("get_strided_view", &test_native_casters::get_strided_view, py::keep_alive<0, 1>()) // keep_alive<0, 1>() => do not free "self" before the returned view
.def("get_array_adapter", &test_native_casters::get_array_adapter, py::keep_alive<0, 1>()) // keep_alive<0, 1>() => do not free "self" before the returned adapter
.def("get_tensor_adapter", &test_native_casters::get_tensor_adapter, py::keep_alive<0, 1>()) // keep_alive<0, 1>() => do not free "self" before the returned adapter
.def("get_owning_array_adapter", &test_native_casters::get_owning_array_adapter) // auto memory management as the adapter owns its memory
.def("view_keep_alive_member_function", [](test_native_casters & self, xt::pyarray<double> & a) // keep_alive<0, 2>() => do not free second parameter before the returned view
{return xt::reshape_view(a, {a.size(), });},
py::keep_alive<0, 2>());
}
67 changes: 67 additions & 0 deletions test_python/test_pyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,73 @@ def test_diff_shape_overload(self):
# FIXME: the TypeError information is not informative
xt.diff_shape_overload(np.ones((2, 2, 2)))

def test_native_casters(self):
    """Exercise the native xtensor->numpy casters: keep-alive / ownership
    policies first, then buffer sharing (no-copy) semantics."""
    import gc

    # check keep alive policy for get_strided_view()
    gc.collect()
    obj = xt.test_native_casters()
    a = obj.get_strided_view()
    obj = None
    gc.collect()
    # allocate fresh memory so a dangling (freed) buffer would be clobbered
    _ = np.zeros((100, 100))
    self.assertEqual(a.sum(), a.size)

    # check keep alive policy for get_array_adapter()
    gc.collect()
    obj = xt.test_native_casters()
    a = obj.get_array_adapter()
    obj = None
    gc.collect()
    _ = np.zeros((100, 100))
    self.assertEqual(a.sum(), a.size)

    # check keep alive policy for get_tensor_adapter()
    gc.collect()
    obj = xt.test_native_casters()
    a = obj.get_tensor_adapter()
    obj = None
    gc.collect()
    _ = np.zeros((100, 100))
    self.assertEqual(a.sum(), a.size)

    # check keep alive policy for get_owning_array_adapter()
    # (the adapter owns its buffer, so no keep_alive is involved)
    gc.collect()
    obj = xt.test_native_casters()
    a = obj.get_owning_array_adapter()
    gc.collect()
    _ = np.zeros((100, 100))
    self.assertEqual(a.sum(), a.size)

    # check keep alive policy for view_keep_alive_member_function()
    gc.collect()
    a = np.ones((100, 100))
    b = obj.view_keep_alive_member_function(a)
    obj = None
    a = None
    gc.collect()
    _ = np.zeros((100, 100))
    self.assertEqual(b.sum(), b.size)

    # check shared buffer (ensure that no copy is done)
    obj = xt.test_native_casters()
    arr = obj.get_array()

    strided_view = obj.get_strided_view()
    strided_view[0, 1] = -1
    self.assertEqual(strided_view.shape, (1, 2))
    self.assertEqual(arr[0, 2], -1)

    adapter = obj.get_array_adapter()
    self.assertEqual(adapter.shape, (2, 2))
    adapter[1, 1] = -2
    self.assertEqual(arr[0, 5], -2)

    adapter = obj.get_tensor_adapter()
    self.assertEqual(adapter.shape, (2, 2))
    adapter[1, 1] = -3
    self.assertEqual(arr[0, 5], -3)


class AttributeTest(TestCase):

Expand Down