// NOTE(review): this listing is an extract. Each line carries its original
// line number and many original lines are missing, so it does not compile
// as-is. Comments below describe only the visible statements.
// Template header; the declaration it introduces is not visible in this extract.
937 template<
typename T,
typename DistanceType>
// Constructor fragment: takes the dataset adaptor by const reference and
// initializes members from it and from the build parameters.
984 const DatasetAdaptor& inputData,
987 : m_leaf_max_size(0),
989 index_params(params),
// The distance functor is constructed from the same dataset adaptor.
993 distance(inputData) {
// The 0 from the init list is overwritten with the requested leaf size.
1001 m_leaf_max_size = params.leaf_max_size;
// Cache the number of points reported by the adaptor.
1005 m_size = dataset.kdtree_get_point_count();
// Build step: compute the bounding box of all points, then recursively
// split the index range [0, m_size) into the tree rooted at root_node.
1031 computeBoundingBox(root_bbox);
1032 root_node = divideTree(0, m_size, root_bbox);
// Dimensionality: the compile-time DIM when positive, otherwise the runtime dim.
1046 return static_cast<size_t>(DIM > 0 ? DIM : dim);
// Approximate memory footprint: pool usage (used + wasted, assumed to be in
// bytes) plus one IndexType per point reported by the dataset (presumably
// accounting for the vind index array).
1054 return pool.usedMemory + pool.wastedMemory
1055 + dataset.kdtree_get_point_count() *
sizeof(IndexType);
// findNeighbors fragment: throws std::runtime_error when the index is not
// usable (the guarding condition is not visible here), otherwise runs the
// recursive search from the root into `result`.
1072 template<
typename RESULTSET>
1077 throw std::runtime_error(
1078 "[nanoflann] findNeighbors() called before building the index or no data points.");
// Pruning factor (1 + eps): eps > 0 permits approximate results.
1079 float epsError = 1 + searchParams.
eps;
// Per-dimension distance contributions, one slot per dimension, zeroed.
1082 dists.
assign((DIM > 0 ? DIM : dim), 0);
// Initial lower bound: the query's distance to the root bounding box.
1083 DistanceType distsq = computeInitialDistances(vec, dists);
1084 searchLevel(result, vec, root_node, distsq, dists, epsError);
// knnSearch fragment: the result set is bound to the caller's output
// buffers for indices and squared distances. The trailing int parameter
// (default 10) is unnamed, so the body cannot reference it.
1094 const size_t num_closest, IndexType *out_indices,
1096 const int = 10)
const {
1098 resultSet.
init(out_indices, out_distances_sq);
// radiusSearch fragment: results are collected into IndicesDists as
// (index, distance) pairs (presumably via resultSet; its construction is
// not visible). Returns the number of entries held by the result set.
1116 std::vector<std::pair<IndexType, DistanceType> >& IndicesDists,
1119 this->findNeighbors(resultSet, query_point, searchParams);
1124 return resultSet.
size();
// Index-only variant: runs the search, sorts the output list of indices in
// ascending index order (std::list::sort with operator<), and returns the
// result count.
1128 std::list<IndexType>& IndicesDists,
1131 this->findNeighbors(resultList, query_point, searchParams);
1134 IndicesDists.sort();
1136 return resultList.
size();
// Index-array initialization fragment: refresh the point count and resize
// vind to match. The loop body is not visible; presumably it fills vind[i] = i.
1145 m_size = dataset.kdtree_get_point_count();
1146 if (vind.size() != m_size)
1147 vind.resize(m_size);
1148 for (
size_t i = 0; i < m_size; i++)
// Read coordinate `component` of point `idx` through the dataset adaptor.
1153 inline ElementType dataset_get(
size_t idx,
int component)
const {
1154 return dataset.kdtree_get_pt(idx, component);
// Recursively serialize the subtree at `tree` to `stream`, recursing into
// child1 and then child2 when they are non-NULL. The write of the node itself
// is among the lines not visible in this extract.
1157 void save_tree(FILE* stream, NodePtr tree) {
1159 if (tree->child1 != NULL) {
1160 save_tree(stream, tree->child1);
1162 if (tree->child2 != NULL) {
1163 save_tree(stream, tree->child2);
// Recursively deserialize a subtree: allocate the node from the pool, then
// recurse into child1/child2 when they are non-NULL.
// NOTE(review): the child pointers must already hold the values read from
// `stream` before they are tested. That read is not visible in this extract;
// confirm it happens between allocation and the checks.
1167 void load_tree(FILE* stream, NodePtr& tree) {
1168 tree =
pool.allocate<Node>();
1170 if (tree->child1 != NULL) {
1171 load_tree(stream, tree->child1);
1173 if (tree->child2 != NULL) {
1174 load_tree(stream, tree->child2);
// Compute the axis-aligned bounding box of all points into `bbox`.
// First asks the adaptor (kdtree_get_bbox); when that does not supply one,
// scans every point. Throws std::runtime_error when there are no points
// (the guard condition itself is not visible in this extract).
1178 void computeBoundingBox(BoundingBox& bbox) {
1179 bbox.resize((DIM > 0 ? DIM : dim));
1180 if (dataset.kdtree_get_bbox(bbox)) {
1183 const size_t N = dataset.kdtree_get_point_count();
1185 throw std::runtime_error(
1186 "[nanoflann] computeBoundingBox() called but no data points found.");
// Seed every dimension's [low, high] with point 0...
1187 for (
int i = 0; i < (DIM > 0 ? DIM : dim); ++i) {
1188 bbox[i].low = bbox[i].high = dataset_get(0, i);
// ...then widen it with each remaining point.
1190 for (
size_t k = 1; k < N; ++k) {
1191 for (
int i = 0; i < (DIM > 0 ? DIM : dim); ++i) {
1192 if (dataset_get(k, i) < bbox[i].low)
1193 bbox[i].low = dataset_get(k, i);
1194 if (dataset_get(k, i) > bbox[i].high)
1195 bbox[i].high = dataset_get(k, i);
// Recursively build the subtree over vind[left, right).
// - Ranges of at most m_leaf_max_size points become leaves, and `bbox` is
//   recomputed as the tight box of their points.
// - Larger ranges are split with middleSplit_ and both halves are built
//   recursively. `bbox` is then overwritten with the union of the children's
//   boxes.
// Returns the newly pool-allocated node.
1210 NodePtr divideTree(
const IndexType left,
const IndexType right,
1211 BoundingBox& bbox) {
1212 NodePtr node =
pool.allocate<Node>();
1215 if ((right - left) <= m_leaf_max_size) {
// Leaf: no children; remember the index range it covers.
1216 node->child1 = node->child2 = NULL;
1217 node->lr.left = left;
1218 node->lr.right = right;
// NOTE(review): vind[left] is read unconditionally, which assumes the leaf
// is non-empty (right > left). Confirm callers never produce empty ranges.
1221 for (
int i = 0; i < (DIM > 0 ? DIM : dim); ++i) {
1222 bbox[i].low = dataset_get(vind[left], i);
1223 bbox[i].high = dataset_get(vind[left], i);
1225 for (IndexType k = left + 1; k < right; ++k) {
1226 for (
int i = 0; i < (DIM > 0 ? DIM : dim); ++i) {
1227 if (bbox[i].low > dataset_get(vind[k], i))
1228 bbox[i].low = dataset_get(vind[k], i);
1229 if (bbox[i].high < dataset_get(vind[k], i))
1230 bbox[i].high = dataset_get(vind[k], i);
// Inner node: choose split dimension/value and reorder vind[left, right)
// so that the first `idx` entries form the left child.
1236 DistanceType cutval;
1237 middleSplit_(&vind[0] + left, right - left, idx, cutfeat, cutval, bbox);
1239 node->sub.divfeat = cutfeat;
// Each child starts from the parent box clipped at cutval. The recursive
// call then tightens it, because divideTree writes back to its bbox argument.
1241 BoundingBox left_bbox(bbox);
1242 left_bbox[cutfeat].high = cutval;
1243 node->child1 = divideTree(left, left + idx, left_bbox);
1245 BoundingBox right_bbox(bbox);
1246 right_bbox[cutfeat].low = cutval;
1247 node->child2 = divideTree(left + idx, right, right_bbox);
// Store the gap between the children along the split axis, taken from the
// tightened child boxes.
1249 node->sub.divlow = left_bbox[cutfeat].high;
1250 node->sub.divhigh = right_bbox[cutfeat].low;
// Report the union of both child boxes back to the caller.
1252 for (
int i = 0; i < (DIM > 0 ? DIM : dim); ++i) {
1253 bbox[i].low = std::min(left_bbox[i].low, right_bbox[i].low);
1254 bbox[i].high = std::max(left_bbox[i].high, right_bbox[i].high);
// Minimum and maximum of coordinate `element` over the `count` points whose
// indices are in ind[0, count). The results go to min_elem/max_elem.
// Assumes count >= 1, because ind[0] is read unconditionally.
1261 void computeMinMax(IndexType* ind, IndexType count,
int element,
1262 ElementType& min_elem, ElementType& max_elem) {
1263 min_elem = dataset_get(ind[0], element);
1264 max_elem = dataset_get(ind[0], element);
1265 for (IndexType i = 1; i < count; ++i) {
1266 ElementType val = dataset_get(ind[i], element);
// Split heuristic for ind[0, count), in four steps:
// 1. Pick the dimension with the largest bounding-box span.
// 2. Cut at that box midpoint.
// 3. Refine the cut using the points' actual [min, max] on that dimension,
//    and check other dimensions whose box span exceeds the current spread.
// 4. Partition with planeSplit and choose `index`.
// NOTE(review): this extract omits several original lines (see gaps in the
// embedded numbering). Only the visible statements are documented.
1274 void middleSplit(IndexType* ind, IndexType count, IndexType& index,
1275 int& cutfeat, DistanceType& cutval,
1276 const BoundingBox& bbox) {
1278 ElementType max_span = bbox[0].high - bbox[0].low;
1280 cutval = (bbox[0].high + bbox[0].low) / 2;
1281 for (
int i = 1; i < (DIM > 0 ? DIM : dim); ++i) {
// FIX: the span is high - low. The previous `low - low` was always 0, so no
// dimension other than 0 could ever be chosen here.
1282 ElementType span = bbox[i].high - bbox[i].low;
1283 if (span > max_span) {
1286 cutval = (bbox[i].high + bbox[i].low) / 2;
// Refine using the actual extent of the points on the chosen dimension.
1291 ElementType min_elem, max_elem;
1292 computeMinMax(ind, count, cutfeat, min_elem, max_elem);
1293 cutval = (min_elem + max_elem) / 2;
1294 max_span = max_elem - min_elem;
// Only dimensions whose box span exceeds the current spread can beat it, so
// only those get a full computeMinMax pass.
// FIX: the loop index is int, matching the other dimension loops and the
// `int element` parameter of computeMinMax (no signed/unsigned mixing).
1298 for (
int i = 0; i < (DIM > 0 ? DIM : dim); ++i) {
1301 ElementType span = bbox[i].high - bbox[i].low;
1302 if (span > max_span) {
1303 computeMinMax(ind, count, i, min_elem, max_elem);
1304 span = max_elem - min_elem;
1305 if (span > max_span) {
1308 cutval = (min_elem + max_elem) / 2;
// Partition around cutval. The split index is then chosen from lim1/lim2
// relative to count / 2 (the assignments are among the omitted lines).
1312 IndexType lim1, lim2;
1313 planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
1315 if (lim1 > count / 2)
1317 else if (lim2 < count / 2)
// Split heuristic used by divideTree for ind[0, count):
// 1. Find the largest bounding-box span.
// 2. Among dimensions whose span is within a factor (1 - EPS) of it, prefer
//    the one whose points have the largest actual spread. The resulting
//    cutfeat assignment is among the omitted lines.
// 3. Cut at that dimension's box midpoint, adjusted when it falls outside the
//    points' [min, max].
// 4. Partition with planeSplit.
// NOTE(review): lines are omitted from this extract (gaps in the embedded
// numbering). Comments cover only the visible statements.
1323 void middleSplit_(IndexType* ind, IndexType count, IndexType& index,
1324 int& cutfeat, DistanceType& cutval,
1325 const BoundingBox& bbox) {
1326 const DistanceType EPS =
static_cast<DistanceType
>(0.00001);
1327 ElementType max_span = bbox[0].high - bbox[0].low;
1328 for (
int i = 1; i < (DIM > 0 ? DIM : dim); ++i) {
1329 ElementType span = bbox[i].high - bbox[i].low;
1330 if (span > max_span) {
// NOTE(review): -1 as a sentinel assumes ElementType is signed. For an
// unsigned ElementType it would wrap to the maximum value; confirm.
1334 ElementType max_spread = -1;
1336 for (
int i = 0; i < (DIM > 0 ? DIM : dim); ++i) {
1337 ElementType span = bbox[i].high - bbox[i].low;
1338 if (span > (1 - EPS) * max_span) {
1339 ElementType min_elem, max_elem;
// FIX: measure the spread of candidate dimension i. The previous code passed
// cutfeat, so it measured the currently selected dimension instead of the
// candidate, and the "largest spread" choice did not compare candidates.
1340 computeMinMax(ind, count, i, min_elem, max_elem);
1341 ElementType spread = max_elem - min_elem;
1343 if (spread > max_spread) {
1345 max_spread = spread;
// Start from the box midpoint on the chosen dimension. It is adjusted when
// it lies outside the points' actual [min, max] on that dimension.
1350 DistanceType split_val = (bbox[cutfeat].low + bbox[cutfeat].high) / 2;
1351 ElementType min_elem, max_elem;
1352 computeMinMax(ind, count, cutfeat, min_elem, max_elem);
1354 if (split_val < min_elem)
1356 else if (split_val > max_elem)
// Partition around cutval. The split index is then chosen from lim1/lim2
// relative to count / 2 (the assignments are among the omitted lines).
1361 IndexType lim1, lim2;
1362 planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
1364 if (lim1 > count / 2)
1366 else if (lim2 < count / 2)
// Reorder ind[0, count) around cutval on dimension cutfeat, in two
// swap-based passes:
//   pass 1 moves entries with coordinate <  cutval to the front;
//   pass 2 moves entries with coordinate <= cutval ahead of those > cutval.
// lim1/lim2 receive the resulting boundaries (assignments not visible here).
1381 void planeSplit(IndexType* ind,
const IndexType count,
int cutfeat,
1382 DistanceType cutval, IndexType& lim1, IndexType& lim2) {
// NOTE(review): count - 1 wraps if count == 0 and IndexType is unsigned.
// Presumably callers never pass an empty range.
1385 IndexType right = count - 1;
// Pass 1. The `right &&` / `!right` checks stop the scan at index 0 rather
// than decrementing below it.
1387 while (left <= right && dataset_get(ind[left], cutfeat) < cutval)
1389 while (right && left <= right
1390 && dataset_get(ind[right], cutfeat) >= cutval)
1392 if (left > right || !right)
1394 std::swap(ind[left], ind[right]);
// Pass 2: separate values equal to cutval from those strictly greater.
1404 while (left <= right && dataset_get(ind[left], cutfeat) <= cutval)
1406 while (right && left <= right && dataset_get(ind[right], cutfeat) > cutval)
1408 if (left > right || !right)
1410 std::swap(ind[left], ind[right]);
// Per-dimension distance from query `vec` to the root bounding box.
// When vec lies below (above) the box on axis i, dists[i] gets accum_dist to
// root_bbox[i].low (.high). The accumulation into distsq and the return are
// among the omitted lines.
1417 DistanceType computeInitialDistances(
const ElementType* vec,
1418 distance_vector_t& dists)
const {
1420 DistanceType distsq = 0.0;
1422 for (
int i = 0; i < (DIM > 0 ? DIM : dim); ++i) {
1423 if (vec[i] < root_bbox[i].low) {
1424 dists[i] =
distance.accum_dist(vec[i], root_bbox[i].low, i);
1427 if (vec[i] > root_bbox[i].high) {
1428 dists[i] =
distance.accum_dist(vec[i], root_bbox[i].high, i);
// Recursive kd-tree search step.
// - Leaf (both children NULL): test every point in vind[lr.left, lr.right).
//   A point is offered to the result set when its distance is below
//   worstDist(), which is read once before the loop.
// - Inner node: descend first into the child on the query's side of the
//   split. Then visit the other child only if the updated lower bound
//   (mindistsq), scaled by epsError, does not exceed the current worstDist().
1440 template<
class RESULTSET>
1441 void searchLevel(RESULTSET& result_set,
const ElementType* vec,
1442 const NodePtr node, DistanceType mindistsq,
1443 distance_vector_t& dists,
const float epsError)
const {
1445 if ((node->child1 == NULL) && (node->child2 == NULL)) {
1447 DistanceType worst_dist = result_set.worstDist();
1448 for (IndexType i = node->lr.left; i < node->lr.right; ++i) {
1449 const IndexType index = vind[i];
1450 DistanceType dist =
distance(vec, index, (DIM > 0 ? DIM : dim));
1451 if (dist < worst_dist) {
1452 result_set.addPoint(dist, vind[i]);
// Inner node: compare the query coordinate on the split dimension against
// the gap [divlow, divhigh] between the two children.
1459 int idx = node->sub.divfeat;
1460 ElementType val = vec[idx];
1461 DistanceType diff1 = val - node->sub.divlow;
1462 DistanceType diff2 = val - node->sub.divhigh;
1466 DistanceType cut_dist;
// diff1 + diff2 < 0  <=>  val < (divlow + divhigh) / 2, so child1 is the
// nearer side. cut_dist is the distance contribution along idx to the far
// child's boundary.
1467 if ((diff1 + diff2) < 0) {
1468 bestChild = node->child1;
1469 otherChild = node->child2;
1470 cut_dist =
distance.accum_dist(val, node->sub.divhigh, idx);
1472 bestChild = node->child2;
1473 otherChild = node->child1;
1474 cut_dist =
distance.accum_dist(val, node->sub.divlow, idx);
1478 searchLevel(result_set, vec, bestChild, mindistsq, dists, epsError);
// Incrementally update the lower bound: swap this dimension's previous
// contribution for cut_dist. Presumably dists[idx] is restored to `dst`
// after the recursive call (not visible in this extract).
1480 DistanceType dst = dists[idx];
1481 mindistsq = mindistsq + cut_dist - dst;
1482 dists[idx] = cut_dist;
1483 if (mindistsq * epsError <= result_set.worstDist()) {
1484 searchLevel(result_set, vec, otherChild, mindistsq, dists, epsError);
// saveIndex fragment: serialize the tree structure starting at the root.
1500 save_tree(stream, root_node);
// loadIndex fragment: rebuild the tree structure from `stream` into root_node.
1513 load_tree(stream, root_node);