I have two overloaded functions, one of which is a template. How should I bind them for Python? #3085

Closed
sun1638650145 opened this issue Jul 9, 2021 · 6 comments

@sun1638650145

This is the class I defined.

namespace initializers {
class RandomNormal: public Initializer {
    public:
        RandomNormal();
        explicit RandomNormal(std::string name);
        explicit RandomNormal(std::string name, std::optional<unsigned int> seed);

        // Overloads: a member function template and a non-template member function
        template<typename Matrix, typename IDtype, typename fDtype>
        Matrix PyCall(const IDtype &attributes_or_structure);

        std::map<std::string, Eigen::MatrixXd> PyCall(const Eigen::RowVectorXi &attributes_or_structure);
};
}  // namespace initializers

This binding fails to compile; it seems that pybind11::overload_cast does not work here.

    pybind11::class_<initializers::RandomNormal, initializers::Initializer>(m, "RandomNormal")
        .def(pybind11::init())
        .def(pybind11::init<std::string>(), pybind11::arg("name"))
        .def(pybind11::init<std::string, std::optional<unsigned int>>(),
             pybind11::arg("name")="random_normal",
             pybind11::arg("seed")=pybind11::none())
        .def_readwrite("name", &initializers::RandomNormal::name)
        .def_readwrite("seed", &initializers::RandomNormal::seed)
        .def("__call__", &initializers::RandomNormal::PyCall<Eigen::MatrixXd, int, double>, pybind11::arg("attributes_or_structure"))
        .def("__call__", &initializers::RandomNormal::PyCall<Eigen::MatrixXf, int, float>, pybind11::arg("attributes_or_structure"))
        .def("__call__", pybind11::overload_cast<const Eigen::RowVectorXi &>(&initializers::RandomNormal::PyCall), pybind11::arg("attributes_or_structure"));

This is the error log.

error: no matching function for call to object of type 'const detail::overload_cast_impl<const Matrix<int, 1, -1, 1> &>'
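Why the last .def fails can be illustrated without pybind11 or Eigen. overload_cast fixes the argument types up front and deduces the return type (and class) from the pointer you pass it. C++ only performs that deduction when the overload set contains no function templates; if the set contains a template, the parameter becomes a non-deduced context and no operator() matches, which is consistent with the "no matching function for call" error above. The sketch below is a simplified model of the mechanism, not pybind11's actual source; overload_cast_sketch and Plain are hypothetical names.

```cpp
#include <cassert>
#include <string>

// Hypothetical, simplified model of the idea behind pybind11::overload_cast:
// the argument types are given explicitly, while the return type and the
// class are deduced from the member-function pointer that is passed in.
template <typename... Args>
struct overload_cast_sketch {
    template <typename Return, typename Class>
    constexpr auto operator()(Return (Class::*pmf)(Args...)) const noexcept {
        return pmf;
    }
};

// Hypothetical class whose PyCall overloads are all ordinary functions.
struct Plain {
    int PyCall(const int &x) { return x + 1; }
    std::string PyCall(const std::string &s) { return s + "!"; }
};

// Works: C++ tries each non-template overload in turn, and only
// int (Plain::*)(const int &) matches Args = <const int &>, so Return = int
// and Class = Plain are deduced.
//
// If Plain also declared a member template such as
//     template <typename T> T PyCall(const T &);
// the overload set would contain a function template, the parameter would
// become a non-deduced context, and this call would fail with a
// "no matching function for call" error of the same kind as in the log.
const auto plain_int = overload_cast_sketch<const int &>{}(&Plain::PyCall);
```

In the real binding, one workaround is to spell out the full member-function pointer type so nothing has to be deduced, e.g. static_cast<std::map<std::string, Eigen::MatrixXd> (initializers::RandomNormal::*)(const Eigen::RowVectorXi &)>(&initializers::RandomNormal::PyCall); another is to wrap the call in a lambda.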

@sun1638650145
Author

I feel the previous example is too complicated, so I wrote a simplified one.

#include <iostream>
#include <vector>
#include "pybind11/pybind11.h"

template <typename D>
D sum(const D &a, const D &b) {
    return a + b;
}

std::vector<int> sum(const int &a, const int &b) {
    std::vector<int> list;
    list.push_back(a);
    list.push_back(b);

    return list;
}

PYBIND11_MODULE(mymath, m) {
    m.def("sum", &sum<int>);
    m.def("sum", &sum<double>);
    m.def("sum", pybind11::overload_cast<const int &, const int &>(&sum));
}
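The overload_cast line here names an overload set, sum, that contains a function template, so its return type cannot be deduced. A workaround that the pybind11 documentation also shows for overloaded methods is a static_cast to the complete function pointer type, so nothing has to be deduced. Below is a minimal plain-C++ sketch (no pybind11) reusing the two sum overloads; SumToList, sum_to_list and sum_double are hypothetical names.

```cpp
#include <cassert>
#include <vector>

template <typename D>
D sum(const D &a, const D &b) {
    return a + b;
}

std::vector<int> sum(const int &a, const int &b) {
    return {a, b};  // same result as the push_back version
}

// The target type is spelled out in full, so no return type has to be
// deduced; only the non-template overload has exactly this type.
using SumToList = std::vector<int> (*)(const int &, const int &);
const SumToList sum_to_list = static_cast<SumToList>(&sum);

// A specialization of the template is still selected by naming it explicitly.
const auto sum_double = &sum<double>;
```

In the module this would presumably become m.def("sum", static_cast<std::vector<int> (*)(const int &, const int &)>(&sum)). Registration order still matters: if sum<int> is registered first, it will accept two ints before this overload is tried.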

@jiwaszki
Contributor

jiwaszki commented Jul 9, 2021

Hi @sun1638650145 ,

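First, please look at the comment I made (from another account) about overloading functions: #3035
If you uncomment the first sum registration in the code below, pybind11 will use it instead of the overload you probably want, because overloads are tried in the order they were registered and sum<int> already accepts two ints.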

Remember that you can always use lambdas to resolve overloading issues. In my opinion this is a much simpler and cleaner solution.

Also, include pybind11/stl.h so that STL containers such as std::vector are converted to Python containers (e.g. list) when returned from functions.

Please try this out and let me know how it works for you! 😄

#include <iostream>
#include <vector>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

template <typename D>
D sum(const D &a, const D &b)
{
    return a + b;
}

std::vector<int> sum(const int &a, const int &b)
{
    std::vector<int> list;
    list.push_back(a);
    list.push_back(b);

    return list;
}

PYBIND11_MODULE(mymodule, m)
{
    // m.def("sum", &sum<int>);
    m.def("sum", &sum<double>);
    m.def("sum", [](const int &a, const int &b)
          { return sum(a, b); });
}
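The lambda works because sum(a, b) inside its body is an ordinary call, so regular overload resolution applies instead of deducing a return type from the address of an overloaded name. The plain-C++ sketch below (no pybind11; sum_ints and sum_doubles are hypothetical names) checks which overload the lambda ends up calling.

```cpp
#include <cassert>
#include <type_traits>
#include <vector>

template <typename D>
D sum(const D &a, const D &b) {
    return a + b;
}

std::vector<int> sum(const int &a, const int &b) {
    return {a, b};
}

// Inside the lambda, sum(a, b) is an ordinary function call. For two ints the
// template specialization sum<int> and the non-template overload are equally
// good matches, and C++ prefers the non-template, so a vector is returned.
const auto sum_ints = [](const int &a, const int &b) { return sum(a, b); };

// When the template is what you want, name the specialization explicitly.
const auto sum_doubles = [](const double &a, const double &b) { return sum<double>(a, b); };

static_assert(std::is_same<decltype(sum_ints(1, 2)), std::vector<int>>::value,
              "the non-template overload is chosen for int arguments");
```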

@sun1638650145
Author

sun1638650145 commented Jul 10, 2021

@jiwaszki First of all, thank you very much for your help. However, my previous examples did not explain my problem very well. Lambdas seem to solve the problem only for free (static) functions; for member functions of a class they do not seem to work. So I wrote a new example. Thank you again.

#include "pybind11/eigen.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"


class Foo {
public:
    unsigned int seed = 1234;

    template<typename T>
    std::variant<int, std::vector<Eigen::MatrixXf>>
    sum (const T &a, const T &b) {
        if (typeid(a) == typeid(int)) {
            return a + b;
        }
        if (typeid(a) == typeid(Eigen::MatrixXf)) {
            std::vector<T> list;
            list.push_back(this->seed * a);
            list.push_back(this->seed * b);

            return list;
        }
    }
};


PYBIND11_MODULE(mymath, m) {
    pybind11::class_<Foo>(m, "Foo")
            .def(pybind11::init())
            .def("sum", pybind11::overload_cast<const int &, const int &>(&Foo::sum<int>))
            .def("sum", pybind11::overload_cast<const Eigen::MatrixXf &, const Eigen::MatrixXf &>
                    (&Foo::sum<std::vector<Eigen::MatrixXf>>));
}
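In this version both typeid branches are compiled for every T, so sum<int> would have to return a std::vector<int> and sum<Eigen::MatrixXf> would have to return a + b, neither of which converts to the declared std::variant. One possible restructuring uses C++17's if constexpr, which discards the branch that does not apply to T, so no path is left without a return. The sketch below assumes C++17 and replaces Eigen::MatrixXf with a hypothetical Matrix alias for std::vector<float> (and a hypothetical scaled helper) so it compiles standalone; it is not the code from the gist.

```cpp
#include <cassert>
#include <type_traits>
#include <variant>
#include <vector>

// Hypothetical stand-in for Eigen::MatrixXf so the sketch compiles without Eigen.
using Matrix = std::vector<float>;
using SumResult = std::variant<int, std::vector<Matrix>>;

// Hypothetical helper: element-wise multiplication by a scalar.
Matrix scaled(const Matrix &m, float factor) {
    Matrix out = m;
    for (float &v : out) v *= factor;
    return out;
}

class Foo {
public:
    unsigned int seed = 1234;

    template <typename T>
    SumResult sum(const T &a, const T &b) {
        // Unlike a runtime typeid check, if constexpr discards the branch that
        // does not apply to T, so each instantiation only compiles one return.
        if constexpr (std::is_same_v<T, int>) {
            return a + b;
        } else {
            static_assert(std::is_same_v<T, Matrix>, "sum supports int and Matrix only");
            const float factor = static_cast<float>(seed);
            return std::vector<Matrix>{scaled(a, factor), scaled(b, factor)};
        }
    }
};
```

With Eigen, the template arguments in the bindings would then be the parameter types, i.e. &Foo::sum<int> and &Foo::sum<Eigen::MatrixXf>; the original binding uses sum<std::vector<Eigen::MatrixXf>>, which does not match the parameter type. Eigen also does not mix scalar types implicitly, so seed would likely need a cast to float before multiplying a MatrixXf.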

@jiwaszki
Contributor

jiwaszki commented Jul 10, 2021

Thanks @sun1638650145, that sheds some light on the problem.

Once again, I would advise using lambdas here. You also need to add all possible return types to the std::variant, even those your design never actually returns, because both typeid branches are compiled for every T. Lambdas will correctly resolve all overloads. The compiler will probably complain about a missing return at the end of the function; I leave that up to you and your code.
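Lambdas are not limited to free functions: for a member function, the lambda can take the object as its first parameter, and pybind11's class_::def treats that parameter as self. A minimal plain-C++ sketch, assuming a hypothetical Counter class whose add has a member template and a non-template overload, mirroring the shape of RandomNormal::PyCall:

```cpp
#include <cassert>
#include <string>

// Hypothetical class with a member template and a non-template overload
// sharing one name.
struct Counter {
    int base = 10;

    template <typename T>
    T add(const T &x) {
        return static_cast<T>(base) + x;
    }

    std::string add(const std::string &s) {
        return std::to_string(base) + s;
    }
};

// The first parameter plays the role of Python's self. Each lambda body is an
// ordinary member call, so overload resolution picks the right add.
const auto add_int = [](Counter &self, const int &x) { return self.add(x); };
const auto add_text = [](Counter &self, const std::string &s) { return self.add(s); };
```

With pybind11, such lambdas could presumably be passed straight to .def("add", ...) on pybind11::class_<Counter>, and Python would call them as counter.add(5).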

Looking forward to your feedback. The code is in this gist:
https://gist.github.com/jiwaszki/adeb35a922b37224087c749eb17bceb2

Note: if you want to call the method as Foo.sum() without creating an object, you need to make the function static and seed static const as well. Then call it inside the lambda as return Foo::sum<int>(a, b); and remove the Foo &self parameter.
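A minimal sketch of that note, assuming a simplified Foo whose sum just scales a + b by seed (a hypothetical body chosen to keep the example short; the real method returns a std::variant):

```cpp
#include <cassert>

// Simplified, hypothetical Foo: sum is static and seed is a static constexpr
// member, so neither needs an instance. seed * (a + b) is only a placeholder
// for the real computation.
class Foo {
public:
    static constexpr unsigned int seed = 1234;

    template <typename T>
    static T sum(const T &a, const T &b) {
        return static_cast<T>(seed) * (a + b);
    }
};

// The wrapper lambda no longer takes Foo &self; it calls the static member
// through the class name.
const auto sum_int = [](const int &a, const int &b) { return Foo::sum<int>(a, b); };
```

On the pybind11 side, class_::def_static is the usual way to register such a function so that Python can call Foo.sum(1, 2) without creating an instance.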

@sun1638650145
Author

@jiwaszki Thanks for your help, I solved my problem.

@dongshiwen1998

@jiwaszki Thank you for your help, I solved my problem.

Thanks for your help.
