PyAPI: add optional filter argument to KDTree.find
This commit is contained in:
parent 54b95c30ae
commit 9964eed9ac
|
@ -189,26 +189,57 @@ static PyObject *py_kdtree_balance(PyKDTree *self)
|
|||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
struct PyKDTree_NearestData {
|
||||
PyObject *py_filter;
|
||||
bool is_error;
|
||||
};
|
||||
|
||||
static int py_find_nearest_cb(void *user_data, int index, const float co[3], float dist_sq)
|
||||
{
|
||||
UNUSED_VARS(co, dist_sq);
|
||||
|
||||
struct PyKDTree_NearestData *data = user_data;
|
||||
|
||||
PyObject *py_args = PyTuple_New(1);
|
||||
PyTuple_SET_ITEM(py_args, 0, PyLong_FromLong(index));
|
||||
PyObject *result = PyObject_CallObject(data->py_filter, py_args);
|
||||
Py_DECREF(py_args);
|
||||
|
||||
if (result) {
|
||||
bool use_node;
|
||||
int ok = PyC_ParseBool(result, &use_node);
|
||||
Py_DECREF(result);
|
||||
if (ok) {
|
||||
return (int)use_node;
|
||||
}
|
||||
}
|
||||
|
||||
data->is_error = true;
|
||||
return -1;
|
||||
}
|
||||
|
||||
PyDoc_STRVAR(py_kdtree_find_doc,
|
||||
".. method:: find(co)\n"
|
||||
".. method:: find(co, filter=None)\n"
|
||||
"\n"
|
||||
" Find nearest point to ``co``.\n"
|
||||
"\n"
|
||||
" :arg co: 3d coordinates.\n"
|
||||
" :type co: float triplet\n"
|
||||
" :arg filter: function which takes an index and returns True for indices to include in the search.\n"
|
||||
" :type filter: callable\n"
|
||||
" :return: Returns (:class:`Vector`, index, distance).\n"
|
||||
" :rtype: :class:`tuple`\n"
|
||||
);
|
||||
/* KDTree.find(co, filter=None): nearest-point query, optionally restricted by a
 * Python callable invoked through py_find_nearest_cb.
 * NOTE(review): this view interleaves removed and added diff lines; where two
 * versions of a line appear, the first is the removed one. The hunk header
 * inside the body hides the lines between argument parsing and the first
 * visible "return NULL;". */
static PyObject *py_kdtree_find(PyKDTree *self, PyObject *args, PyObject *kwargs)
|
||||
{
|
||||
/* NOTE(review): removed declaration, superseded by the next one. */
PyObject *py_co;
|
||||
/* py_filter stays NULL when ``filter`` is omitted; an explicit None is stored
 * as Py_None, which is not NULL. */
PyObject *py_co, *py_filter = NULL;
|
||||
float co[3];
|
||||
KDTreeNearest nearest;
|
||||
/* NOTE(review): removed keyword list (no ``filter``). */
const char *keywords[] = {"co", NULL};
|
||||
const char *keywords[] = {"co", "filter", NULL};
|
||||

||||
if (!PyArg_ParseTupleAndKeywords(
|
||||
/* NOTE(review): the next two lines are the removed "O:find" form; the
 * "O|O:find" pair after them replaces it. */
args, kwargs, (char *) "O:find", (char **)keywords,
|
||||
&py_co))
|
||||
args, kwargs, (char *) "O|O:find", (char **)keywords,
|
||||
&py_co, &py_filter))
|
||||
{
|
||||
return NULL;
|
||||
}
|
||||

@ -221,10 +252,26 @@ static PyObject *py_kdtree_find(PyKDTree *self, PyObject *args, PyObject *kwargs
|
|||
return NULL;
|
||||
}
|
||||

||||

||||
/* -1 means "no match"; presumably left untouched when the filter rejects
 * every node. */
nearest.index = -1;
|
||||

||||
/* NOTE(review): removed unconditional search; it now runs only when no
 * filter is given. */
BLI_kdtree_find_nearest(self->obj, co, &nearest);
|
||||
/* Plain search only when ``filter`` was omitted. An explicit filter=None takes
 * the else-branch, and calling None raises TypeError (test_kdtree_invalid_filter
 * relies on this). NOTE(review): a non-callable filter is only detected if
 * the callback actually runs; a PyCallable_Check here would catch it always. */
if (py_filter == NULL) {
|
||||
BLI_kdtree_find_nearest(self->obj, co, &nearest);
|
||||
}
|
||||
else {
|
||||
struct PyKDTree_NearestData data = {0};
|
||||

||||
data.py_filter = py_filter;
|
||||
data.is_error = false;
|
||||

||||
/* py_find_nearest_cb reports Python failures through data.is_error. */
BLI_kdtree_find_nearest_cb(
|
||||
self->obj, co,
|
||||
py_find_nearest_cb, &data,
|
||||
&nearest);
|
||||

||||
/* The filter call or its result parsing failed; the Python exception is
 * presumably already set, so just propagate it. */
if (data.is_error) {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||

||||
/* NOTE(review): presumably maps index -1 to None entries; the filter tests
 * expect ret[1] to be None when every node is filtered out. */
return kdtree_nearest_to_py_and_check(&nearest);
|
||||
}
|
||||
|
|
|
@ -240,17 +240,23 @@ class QuaternionTesting(unittest.TestCase):
|
|||
|
||||
|
||||
class KDTreeTesting(unittest.TestCase):
|
||||
|
||||
@staticmethod
def kdtree_create_grid_3d_data(tot):
    """Yield ``(co, index)`` for every point of a ``tot`` x ``tot`` x ``tot`` grid.

    Points are evenly spaced over the unit cube ``[0, 1]`` on each axis and are
    numbered from 0, with ``x`` varying slowest and ``z`` fastest.  A grid with
    ``tot == 1`` is the single point at the origin; ``tot <= 0`` yields nothing.
    """
    # Spacing between neighbouring points. A one-point grid has no spacing and
    # would otherwise divide by zero.
    mul = 1.0 / (tot - 1) if tot > 1 else 0.0
    index = 0
    for x in range(tot):
        for y in range(tot):
            for z in range(tot):
                yield (x * mul, y * mul, z * mul), index
                index += 1
|
||||
|
||||
@staticmethod
def kdtree_create_grid_3d(tot, *, filter_fn=None):
    """Return a balanced KDTree built from ``kdtree_create_grid_3d_data(tot)``.

    :arg filter_fn: optional ``(co, index)`` predicate; only points for which it
       returns a truthy value are inserted.  The tree is always sized for the
       full ``tot ** 3`` grid.
    """
    tree = kdtree.KDTree(tot * tot * tot)
    for point, idx in KDTreeTesting.kdtree_create_grid_3d_data(tot):
        # With no predicate every grid point is kept.
        if filter_fn is None or filter_fn(point, idx):
            tree.insert(point, idx)
    tree.balance()
    return tree
|
||||
|
||||
|
@ -327,6 +333,49 @@ class KDTreeTesting(unittest.TestCase):
|
|||
ret = k.find_n((1.0,) * 3, tot)
|
||||
self.assertEqual(len(ret), tot)
|
||||
|
||||
def test_kdtree_grid_filter_simple(self):
    """A filter accepting a single index always finds exactly that node."""
    size = 10
    tree = self.kdtree_create_grid_3d(size)

    corner = (1.0,) * 3
    expected = tree.find(corner)

    def only_expected(i):
        return i == expected[1]

    # Searching from the same point gives the identical result.
    self.assertEqual(expected, tree.find(corner, filter=only_expected))

    # From the opposite side the same node is found; only the distance differs.
    found = tree.find((-1.0,) * 3, filter=only_expected)
    self.assertEqual(expected[:2], found[:2])
|
||||
|
||||
def test_kdtree_grid_filter_pairs(self):
    """Filtering the full tree by parity matches a tree built with only that parity."""
    size = 10
    k_all = self.kdtree_create_grid_3d(size)
    k_odd = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 1)
    k_evn = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 0)

    # A 5 x 5 x 5 lattice of query points over the unit cube.
    samples = 5
    step = 1 / (samples - 1)
    sample_points = [
        (x * step, y * step, z * step)
        for x in range(samples)
        for y in range(samples)
        for z in range(samples)
    ]

    for co in sample_points:
        for subset, parity in ((k_odd, 1), (k_evn, 0)):
            ret_regular = subset.find(co)
            self.assertEqual(ret_regular[1] % 2, parity)
            ret_filter = k_all.find(co, lambda i, p=parity: (i % 2) == p)
            self.assertEqual(ret_regular, ret_filter)

    # filter out all values (search odd tree for even values and the reverse)
    origin = (0,) * 3
    for subset, parity in ((k_odd, 0), (k_evn, 1)):
        ret_filter = subset.find(origin, lambda i, p=parity: (i % 2) == p)
        self.assertEqual(ret_filter[1], None)
|
||||
|
||||
def test_kdtree_invalid_size(self):
"""A negative size passed to KDTree() must raise ValueError."""
|
||||
with self.assertRaises(ValueError):
|
||||
kdtree.KDTree(-1)
|
||||
|
@ -342,6 +391,21 @@ class KDTreeTesting(unittest.TestCase):
|
|||
with self.assertRaises(RuntimeError):
|
||||
k.find(co)
|
||||
|
||||
def test_kdtree_invalid_filter(self):
    """Bad ``filter`` arguments to ``find`` raise the matching exception."""
    tree = kdtree.KDTree(1)
    origin = (0,) * 3
    tree.insert(origin, 0)
    tree.balance()

    # (expected exception, bad filter) pairs.
    bad_filters = (
        (TypeError, None),             # not callable
        (TypeError, lambda: None),     # no args
        (ValueError, lambda i: None),  # bad return value
    )
    for exc_type, bad_filter in bad_filters:
        with self.assertRaises(exc_type):
            tree.find(origin, filter=bad_filter)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
sys.argv = [__file__] + (sys.argv[sys.argv.index("--") + 1:] if "--" in sys.argv else [])
|
||||
|
|
Loading…
Reference in New Issue