diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index 8a8c32c0ec..4baeaa57a2 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -14,6 +14,7 @@ #include "pybind11.h" #include +#include PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) PYBIND11_NAMESPACE_BEGIN(detail) @@ -129,24 +130,23 @@ struct type_caster> { // See PR #1413 for full details } else { // Check number of arguments of Python function - auto argCountFromFuncCode = [&](handle &obj) { - // This is faster then doing import inspect and - // inspect.signature(obj).parameters - - object argCount = obj.attr("co_argcount"); - return argCount.template cast(); + auto get_argument_count = [](const handle &obj) -> size_t { + // Faster then `import inspect` and `inspect.signature(obj).parameters` + return obj.attr("co_argcount").cast(); }; size_t argCount = 0; - handle codeAttr = PyObject_GetAttrString(src.ptr(), "__code__"); + handle empty; + object codeAttr = getattr(src, "__code__", empty); + if (codeAttr) { - argCount = argCountFromFuncCode(codeAttr); + argCount = get_argument_count(codeAttr); } else { - handle callAttr = PyObject_GetAttrString(src.ptr(), "__call__"); + object callAttr = getattr(src, "__call__", empty); + if (callAttr) { - handle codeAttr2 = PyObject_GetAttrString(callAttr.ptr(), "__code__"); - argCount = argCountFromFuncCode(codeAttr2) - - 1; // we have to remove the self argument + object codeAttr2 = getattr(callAttr, "__code__"); + argCount = get_argument_count(codeAttr2) - 1; // removing the self argument } else { // No __code__ or __call__ attribute, this is not a proper Python function return false; diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index d2afbc2ca3..c81aee6672 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -112,7 +112,8 @@ def __call__(self, a): return a assert m.dummy_function_overloaded_std_func_arg(f) == 9 - assert m.dummy_function_overloaded_std_func_arg(A()) == 9 + a = A() + assert m.dummy_function_overloaded_std_func_arg(a) == 9 assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9 def f2(a, b):