PyAPI: add optional filter argument to KDTree.find

This commit is contained in:
Campbell Barton 2015-12-06 21:33:39 +11:00
parent 54b95c30ae
commit 9964eed9ac
2 changed files with 122 additions and 11 deletions

View File

@ -189,26 +189,57 @@ static PyObject *py_kdtree_balance(PyKDTree *self)
Py_RETURN_NONE;
}
/* User data threaded through BLI_kdtree_find_nearest_cb into py_find_nearest_cb
 * when KDTree.find is called with a ``filter`` argument. */
struct PyKDTree_NearestData {
/* The Python ``filter`` callable; it is called with each candidate index. */
PyObject *py_filter;
/* Set by py_find_nearest_cb when calling the filter or parsing its result fails;
 * py_kdtree_find then returns NULL, relying on the Python error state left by the failing call. */
bool is_error;
};
static int py_find_nearest_cb(void *user_data, int index, const float co[3], float dist_sq)
{
UNUSED_VARS(co, dist_sq);
struct PyKDTree_NearestData *data = user_data;
PyObject *py_args = PyTuple_New(1);
PyTuple_SET_ITEM(py_args, 0, PyLong_FromLong(index));
PyObject *result = PyObject_CallObject(data->py_filter, py_args);
Py_DECREF(py_args);
if (result) {
bool use_node;
int ok = PyC_ParseBool(result, &use_node);
Py_DECREF(result);
if (ok) {
return (int)use_node;
}
}
data->is_error = true;
return -1;
}
/* Docstring for KDTree.find. ``filter`` is optional; without it a plain nearest search is done. */
PyDoc_STRVAR(py_kdtree_find_doc,
".. method:: find(co, filter=None)\n"
"\n"
" Find nearest point to ``co``.\n"
"\n"
" :arg co: 3d coordinates.\n"
" :type co: float triplet\n"
" :arg filter: function which takes an index and returns True for indices to include in the search.\n"
" :type filter: callable\n"
" :return: Returns (:class:`Vector`, index, distance).\n"
" :rtype: :class:`tuple`\n"
);
static PyObject *py_kdtree_find(PyKDTree *self, PyObject *args, PyObject *kwargs)
{
PyObject *py_co;
PyObject *py_co, *py_filter = NULL;
float co[3];
KDTreeNearest nearest;
const char *keywords[] = {"co", NULL};
const char *keywords[] = {"co", "filter", NULL};
if (!PyArg_ParseTupleAndKeywords(
args, kwargs, (char *) "O:find", (char **)keywords,
&py_co))
args, kwargs, (char *) "O|O:find", (char **)keywords,
&py_co, &py_filter))
{
return NULL;
}
@ -221,10 +252,26 @@ static PyObject *py_kdtree_find(PyKDTree *self, PyObject *args, PyObject *kwargs
return NULL;
}
nearest.index = -1;
BLI_kdtree_find_nearest(self->obj, co, &nearest);
if (py_filter == NULL) {
BLI_kdtree_find_nearest(self->obj, co, &nearest);
}
else {
struct PyKDTree_NearestData data = {0};
data.py_filter = py_filter;
data.is_error = false;
BLI_kdtree_find_nearest_cb(
self->obj, co,
py_find_nearest_cb, &data,
&nearest);
if (data.is_error) {
return NULL;
}
}
return kdtree_nearest_to_py_and_check(&nearest);
}

View File

@ -240,17 +240,23 @@ class QuaternionTesting(unittest.TestCase):
class KDTreeTesting(unittest.TestCase):
@staticmethod
def kdtree_create_grid_3d_data(tot):
    """
    Yield ``(co, index)`` pairs for a ``tot * tot * tot`` grid of points
    evenly spaced over the unit cube.

    ``index`` counts up from 0 in x, y, z loop-nesting order (z fastest).
    """
    # With a single sample per axis there is no spacing to divide by;
    # place that lone point at the origin instead of raising ZeroDivisionError.
    mul = 1.0 / (tot - 1) if tot > 1 else 0.0
    index = 0
    for x in range(tot):
        for y in range(tot):
            for z in range(tot):
                yield (x * mul, y * mul, z * mul), index
                index += 1
@staticmethod
def kdtree_create_grid_3d(tot, *, filter_fn=None):
    """
    Build and balance a KDTree sized for ``tot`` cubed points, filled from
    ``kdtree_create_grid_3d_data``.

    When ``filter_fn`` is given, only grid points for which
    ``filter_fn(co, index)`` is truthy are inserted.
    """
    tree = kdtree.KDTree(tot * tot * tot)
    for co, index in KDTreeTesting.kdtree_create_grid_3d_data(tot):
        if filter_fn is None or filter_fn(co, index):
            tree.insert(co, index)
    tree.balance()
    return tree
@ -327,6 +333,49 @@ class KDTreeTesting(unittest.TestCase):
ret = k.find_n((1.0,) * 3, tot)
self.assertEqual(len(ret), tot)
def test_kdtree_grid_filter_simple(self):
    # A filter that accepts only the index found by an unfiltered search
    # must return that same point, from a nearby or a distant query point.
    size = 10
    k = self.kdtree_create_grid_3d(size)
    co_near = (1.0,) * 3
    co_far = (-1.0,) * 3

    ret_regular = k.find(co_near)
    index_expect = ret_regular[1]

    def only_expected(i):
        return i == index_expect

    self.assertEqual(ret_regular, k.find(co_near, filter=only_expected))

    ret_far = k.find(co_far, filter=only_expected)
    # Compare only (co, index): the distance differs for the far query point.
    self.assertEqual(ret_regular[:2], ret_far[:2])
def test_kdtree_grid_filter_pairs(self):
    size = 10
    k_all = self.kdtree_create_grid_3d(size)
    k_odd = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 1)
    k_evn = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 0)

    # A tree built from one parity must agree with a parity-filtered
    # search of the full tree, at every sample point.
    samples = 5
    step = 1 / (samples - 1)
    sample_coords = [
        (x * step, y * step, z * step)
        for x in range(samples)
        for y in range(samples)
        for z in range(samples)
    ]
    for co in sample_coords:
        for k_sub, parity in ((k_odd, 1), (k_evn, 0)):
            ret_regular = k_sub.find(co)
            self.assertEqual(ret_regular[1] % 2, parity)
            ret_filter = k_all.find(co, lambda i: (i % 2) == parity)
            self.assertEqual(ret_regular, ret_filter)

    # Filtering a tree for the parity it does not contain must find nothing.
    co = (0,) * 3
    for k_sub, parity in ((k_odd, 0), (k_evn, 1)):
        ret_filter = k_sub.find(co, lambda i: (i % 2) == parity)
        self.assertEqual(ret_filter[1], None)
def test_kdtree_invalid_size(self):
with self.assertRaises(ValueError):
kdtree.KDTree(-1)
@ -342,6 +391,21 @@ class KDTreeTesting(unittest.TestCase):
with self.assertRaises(RuntimeError):
k.find(co)
def test_kdtree_invalid_filter(self):
    k = kdtree.KDTree(1)
    k.insert((0,) * 3, 0)
    k.balance()

    co = (0,) * 3
    # (expected exception, filter argument)
    bad_filters = (
        (TypeError, None),             # not callable
        (TypeError, lambda: None),     # callable, but accepts no arguments
        (ValueError, lambda i: None),  # return value is not a bool
    )
    for exc_type, filter_fn in bad_filters:
        with self.assertRaises(exc_type):
            k.find(co, filter=filter_fn)
if __name__ == '__main__':
import sys
sys.argv = [__file__] + (sys.argv[sys.argv.index("--") + 1:] if "--" in sys.argv else [])