diff --git a/include/pybind11/buffer_info.h b/include/pybind11/buffer_info.h index 06120d5563..f8689babf3 100644 --- a/include/pybind11/buffer_info.h +++ b/include/pybind11/buffer_info.h @@ -37,6 +37,9 @@ inline std::vector f_strides(const std::vector &shape, ssize_t return strides; } +template +struct compare_buffer_info; + PYBIND11_NAMESPACE_END(detail) /// Information record describing a Python buffer object @@ -150,6 +153,11 @@ struct buffer_info { Py_buffer *view() const { return m_view; } Py_buffer *&view() { return m_view; } + template + static bool compare(const buffer_info &b) { + return detail::compare_buffer_info::compare(b); + } + private: struct private_ctr_tag {}; @@ -170,7 +178,7 @@ struct buffer_info { PYBIND11_NAMESPACE_BEGIN(detail) -template +template struct compare_buffer_info { static bool compare(const buffer_info &b) { return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); diff --git a/tests/test_buffers.cpp b/tests/test_buffers.cpp index ed9013ae7b..ab727552f4 100644 --- a/tests/test_buffers.cpp +++ b/tests/test_buffers.cpp @@ -16,7 +16,7 @@ TEST_SUBMODULE(buffers, m) { m.attr("std_is_same_double_long_double") = std::is_same::value; - m.def("format_descriptor_format_compare", + m.def("format_descriptor_format_buffer_info_compare", [](const std::string &cpp_name, const py::buffer &buffer) { // https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables static auto *format_table = new std::map; @@ -25,7 +25,7 @@ TEST_SUBMODULE(buffers, m) { if (format_table->empty()) { #define PYBIND11_ASSIGN_HELPER(...) \ (*format_table)[#__VA_ARGS__] = py::format_descriptor<__VA_ARGS__>::format(); \ - (*compare_table)[#__VA_ARGS__] = py::detail::compare_buffer_info<__VA_ARGS__>::compare; + (*compare_table)[#__VA_ARGS__] = py::buffer_info::compare<__VA_ARGS__>; PYBIND11_ASSIGN_HELPER(PyObject *) PYBIND11_ASSIGN_HELPER(bool) PYBIND11_ASSIGN_HELPER(std::int8_t) diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 108154231a..4a60ea4261 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -48,10 +48,10 @@ @pytest.mark.parametrize(("cpp_name", "np_dtype"), CPP_NAME_NP_DTYPE_TABLE) -def test_format_descriptor_format_compare(cpp_name, np_dtype): +def test_format_descriptor_format_buffer_info_compare(cpp_name, np_dtype): np_array = np.array([], dtype=np_dtype) for other_cpp_name, expected_format in CPP_NAME_FORMAT_TABLE: - format, np_array_is_matching = m.format_descriptor_format_compare( + format, np_array_is_matching = m.format_descriptor_format_buffer_info_compare( other_cpp_name, np_array ) assert format == expected_format