Skip to content

Commit

Permalink
add argument number dispatch mechanism for std::function casting
Browse files Browse the repository at this point in the history
  • Loading branch information
rath3t committed Aug 3, 2024
1 parent 50acb81 commit ffed130
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 2 deletions.
41 changes: 40 additions & 1 deletion include/pybind11/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,17 @@ struct type_caster<std::function<Return(Args...)>> {
if (detail::is_function_record_capsule(c)) {
rec = c.get_pointer<function_record>();
}

while (rec != nullptr) {
const int correctingSelfArgument = rec->is_method ? 1:0;
if(rec->nargs - correctingSelfArgument != sizeof...(Args)) {
rec = rec->next;
// if the overload is not feasible in terms of number of arguments, we continue to the next one.
// If there is no next one, we return false.
if(rec == nullptr) {
return false;
}
continue;
}
if (rec->is_stateless
&& same_type(typeid(function_type),
*reinterpret_cast<const std::type_info *>(rec->data[1]))) {
Expand All @@ -76,6 +85,36 @@ struct type_caster<std::function<Return(Args...)>> {
// Raising an fail exception here works to prevent the segfault, but only on gcc.
// See PR #1413 for full details
}
else {
// Check number of arguments of Python function
auto getArgCount = [&](PyObject *obj) {
// This is faster then doing import inspect and inspect.signature(obj).parameters
auto* t = PyObject_GetAttrString(obj, "__code__");
auto* argCount = PyObject_GetAttrString(t, "co_argcount");
return PyLong_AsLong(argCount);
};
long argCount = -1;

if(static_cast<bool>(PyObject_HasAttrString(src.ptr(), "__code__"))) {
argCount= getArgCount(src.ptr());
}else {
if(static_cast<bool>(PyObject_HasAttrString(src.ptr(), "__call__")))
{
auto* t2 = PyObject_GetAttrString(src.ptr(), "__call__");
argCount = getArgCount(t2)-1; // we have to remove the self argument
}else {
// No __code__ or __call__ attribute, this is not a proper Python function
return false;
}
}
// if we are a method, we have to correct the argument count since we are not counting the self argument
const int correctingSelfArgument = static_cast<bool>(PyMethod_Check(src.ptr()))? 1:0;

argCount-=correctingSelfArgument;
if (argCount != sizeof...(Args)) {
return false;
}
}

// ensure GIL is held during functor destruction
struct func_handle {
Expand Down
4 changes: 4 additions & 0 deletions tests/test_callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ TEST_SUBMODULE(callbacks, m) {
return "argument does NOT match dummy_function. This should never happen!";
});

// test_cpp_correct_overload_resolution
m.def("dummy_function_overloaded_std_func_arg", [](std::function<int(int)> f) { return 3*f(3); });
m.def("dummy_function_overloaded_std_func_arg", [](std::function<int(int,int)> f) { return 2*f(3,4); });

class AbstractBase {
public:
// [workaround(intel)] = default does not work here
Expand Down
13 changes: 12 additions & 1 deletion tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ def test_cpp_callable_cleanup():
alive_counts = m.test_cpp_callable_cleanup()
assert alive_counts == [0, 1, 2, 1, 2, 1, 0]

def test_cpp_correct_overload_resolution():
def f(a):
return a
assert(m.dummy_function_overloaded_std_func_arg(f)==9)
assert(m.dummy_function_overloaded_std_func_arg(lambda i: i)==9)

def f2(a,b):
return a+b
assert(m.dummy_function_overloaded_std_func_arg(f2)==14)
assert(m.dummy_function_overloaded_std_func_arg(lambda i,j: i+j)==14)


def test_cpp_function_roundtrip():
"""Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer"""
Expand Down Expand Up @@ -130,7 +141,7 @@ def test_cpp_function_roundtrip():
m.test_dummy_function(lambda x, y: x + y)
assert any(
s in str(excinfo.value)
for s in ("missing 1 required positional argument", "takes exactly 2 arguments")
for s in ("incompatible function arguments. The following argument types are", "function test_cpp_function_roundtrip.<locals>.<lambda>")
)


Expand Down
16 changes: 16 additions & 0 deletions tests/test_embed/test_interpreter.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <pybind11/embed.h>
#include <pybind11/functional.h>

// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to
// catch 2.0.1; this should be fixed in the next catch release after 2.0.1).
Expand Down Expand Up @@ -78,6 +79,12 @@ PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
d["missing"].cast<py::object>();
}

PYBIND11_EMBEDDED_MODULE(func_module, m) {
m.def("funcOverload", [](const std::function<int(int, int)> f) {
return f(2, 3);})
.def("funcOverload", [](const std::function<int(int)> f) { return f(2); });
}

TEST_CASE("PYTHONPATH is used to update sys.path") {
// The setup for this TEST_CASE is in catch.cpp!
auto sys_path = py::str(py::module_::import("sys").attr("path")).cast<std::string>();
Expand Down Expand Up @@ -171,6 +178,15 @@ TEST_CASE("There can be only one interpreter") {
py::initialize_interpreter();
}

TEST_CASE("Check the overload resolution from cpp_function objects to std::function") {
auto m = py::module_::import("func_module");
auto f = std::function<int(int)>([](int x) { return 2 * x; });
REQUIRE(m.attr("funcOverload")(f).template cast<int>() == 4);

auto f2 = std::function<int(int,int)>([](int x, int y) { return 2 * x * y; });
REQUIRE(m.attr("funcOverload")(f2).template cast<int>() == 12);
}

#if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
TEST_CASE("Custom PyConfig") {
py::finalize_interpreter();
Expand Down

0 comments on commit ffed130

Please sign in to comment.