Skip to content

Commit

Permalink
Merge pull request #722 from fanosta/buffer-protocol
Browse files Browse the repository at this point in the history
support buffer protocol in Solver.add_clauses
  • Loading branch information
msoos authored Jun 24, 2023
2 parents fa131d1 + 76f33f4 commit 02fc5e8
Showing 1 changed file with 23 additions and 86 deletions.
109 changes: 23 additions & 86 deletions python/src/pycryptosat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,28 +364,21 @@ static int _add_clauses_from_array(Solver *self, const size_t array_length, cons
return 1;
}

static int _add_clauses_from_buffer_info(Solver *self, PyObject *buffer_info, const size_t itemsize)
static int _add_clauses_from_buffer(Solver *self, Py_buffer *view)
{
PyObject *py_array_length = PyTuple_GetItem(buffer_info, 1);
if (py_array_length == NULL) {
PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get array length");
if (view->ndim != 1) {
PyErr_Format(PyExc_ValueError, "invalid clause array: expected 1-D array, got %d-D", view->ndim);
return 0;
}
long array_length = PyLong_AsLong(py_array_length);
if (array_length < 0) {
PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get array length");
return 0;
}
PyObject *py_array_address = PyTuple_GetItem(buffer_info, 0);
if (py_array_address == NULL) {
PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get array address");
return 0;
}
const void *array_address = PyLong_AsVoidPtr(py_array_address);
if (array_address == NULL) {
PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get array address");
if (strcmp(view->format, "i") != 0 && strcmp(view->format, "l") != 0 && strcmp(view->format, "q") != 0) {
PyErr_Format(PyExc_ValueError, "invalid clause array: invalid format '%s'", view->format);
return 0;
}

void * array_address = view->buf;
size_t itemsize = view->itemsize;
size_t array_length = view->len / itemsize;

if (itemsize == sizeof(int)) {
return _add_clauses_from_array(self, array_length, (const int *) array_address);
}
Expand All @@ -399,74 +392,14 @@ static int _add_clauses_from_buffer_info(Solver *self, PyObject *buffer_info, co
return 0;
}

static int _check_array_typecode(PyObject *clauses)
{
PyObject *py_typecode = PyObject_GetAttrString(clauses, "typecode");
if (py_typecode == NULL) {
PyErr_SetString(PyExc_ValueError, "invalid clause array: typecode is NULL");
return 0;
}

PyObject *typecode_bytes = PyUnicode_AsASCIIString(py_typecode);
Py_DECREF(py_typecode);
if (typecode_bytes == NULL) {
PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get typecode bytes");
return 0;
}

const char *typecode_cstr = PyBytes_AsString(typecode_bytes);
if (typecode_cstr == NULL) {
Py_DECREF(typecode_bytes);
PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get typecode cstring");
return 0;
}
const char typecode = typecode_cstr[0];
if (typecode == '\0' || typecode_cstr[1] != '\0') {
PyErr_Format(PyExc_ValueError, "invalid clause array: invalid typecode '%s'", typecode_cstr);
Py_DECREF(typecode_bytes);
return 0;
}
Py_DECREF(typecode_bytes);
if (typecode != 'i' && typecode != 'l' && typecode != 'q') {
PyErr_Format(PyExc_ValueError, "invalid clause array: invalid typecode '%c'", typecode);
return 0;
}
return 1;
}

static int add_clauses_array(Solver *self, PyObject *clauses)
{
if (_check_array_typecode(clauses) == 0) {
return 0;
}
PyObject *py_itemsize = PyObject_GetAttrString(clauses, "itemsize");
if (py_itemsize == NULL) {
PyErr_SetString(PyExc_ValueError, "invalid clause array: itemsize is NULL");
return 0;
}
const long itemsize = PyLong_AsLong(py_itemsize);
Py_DECREF(py_itemsize);
if (itemsize < 0) {
PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get itemsize");
return 0;
}
PyObject *buffer_info = PyObject_CallMethod(clauses, "buffer_info", NULL);
if (buffer_info == NULL) {
PyErr_SetString(PyExc_ValueError, "invalid clause array: buffer_info is NULL");
return 0;
}
int ret = _add_clauses_from_buffer_info(self, buffer_info, itemsize);
Py_DECREF(buffer_info);
return ret;
}

PyDoc_STRVAR(add_clauses_doc,
"add_clauses(clauses)\n\
Add iterable of clauses to the solver.\n\
\n\
:param clauses: List of clauses. Each clause contains literals (ints)\n\
Alternatively, this can be a flat array.array (typecode 'i', 'l', or 'q')\n\
of zero separated and terminated clauses of literals (ints).\n\
Alternatively, this can be a flat array.array or other contiguous\n\
buffer (format 'i', 'l', or 'q') of zero separated and terminated\n\
clauses of literals (ints).\n\
:type clauses: <list> or <array.array>\n\
:return: None\n\
:rtype: <None>"
Expand All @@ -480,12 +413,16 @@ static PyObject* add_clauses(Solver *self, PyObject *args, PyObject *kwds)
return NULL;
}

if (
PyObject_HasAttr(clauses, PyUnicode_FromString("buffer_info")) &&
PyObject_HasAttr(clauses, PyUnicode_FromString("typecode")) &&
PyObject_HasAttr(clauses, PyUnicode_FromString("itemsize"))
) {
int ret = add_clauses_array(self, clauses);
if (PyObject_CheckBuffer(clauses)) {
Py_buffer view;
memset(&view, 0, sizeof(view));
if (PyObject_GetBuffer(clauses, &view, PyBUF_CONTIG_RO | PyBUF_FORMAT) != 0) {
return NULL;
}

int ret = _add_clauses_from_buffer(self, &view);
PyBuffer_Release(&view);

if (ret == 0 || PyErr_Occurred()) {
return 0;
}
Expand Down

0 comments on commit 02fc5e8

Please sign in to comment.