阅读 51

KD树

KD(k-dimensional)树的概念自1975年提出,试图解决的是在k维空间为数据集建立索引的问题。依上文所述,已知样本空间如何快速查询得到其近邻?唯有以空间换时间,建立索引便是计算机世界的解决之道。但是索引建立的方式各有不同,kd树只是是其中一种。它的思想如同分治法,即:利用已有数据对k维空间进行切分。

在机器学习KNN中,KD树也是必不可少的理论基础部分,分文介绍并提供示例代码

参考:

 

概述

二叉树在时间复杂度上是O(logN),远远优于全遍历算法。对于该树,可以在空间上理解:树的每个节点把对应父节点切成的空间再切分,从而形成各个不同的子空间。查找某点的所在位置时,就变成了查找点所在子空间。而KD树引申于二叉树

以二维KD树为例,假设有6个二维数据点{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}。

将二维的平面想像成一块方型蛋糕,kd树构建就是面点师要将蛋糕切成上面示意图的模样。先将平面上的六个点在蛋糕上做好标记。

 

1. KD树的构建

以本例的第一次切分为例,需要知道以x轴还是y轴进行切分比较好,需要判断两个维度的方差,选择最大的来切分

计算可得,x方差较大,按x进行切分。

考虑到让二叉树的深度尽量小,使用二分原则进行划分。即按照中间索引的点作为根节点,剩下按照大小各分左右。

之后,将分好的两个数据也按照此原则进行划分,最终构建出KD树

结果如图所示:

 

2. KD树的查找

在k-d树中进行数据的查找也是特征匹配的重要环节,其目的是检索在k-d树中与查询点距离最近的数据点。

回到面点师切好的糕点平面图,用目标数据在kd树中寻找最近邻时,最核心的两个部分是: 

1 寻找近似点-寻找最近邻的叶子节点作为目标数据的近似最近点。

2 回溯-以目标数据和最近邻的近似点的距离沿树根部进行回溯和迭代。

回溯和迭代的目的是因为找到的点不一定就是最邻近的,最邻近肯定距离查询点更近,应该位于以查询点为圆心且通过叶子节点的圆域内。为了找到真正的最近邻还需要进行‘回溯‘操作:算法沿搜索路径反向查找是否有距离查询点更近的数据点。

 

 

代码

 1 import math
 2 
 3 import numpy as np
 4 
 5 class KdNode:
 6     data = None
 7     left = None
 8     right = None
 9 
10     def __init__(self, data, left, right):
11         self.data = data
12         self.left = left
13         self.right = right
14 
15 
16 def distance(p1, p2):
17     dimension = p1.size
18     sum = 0.
19     for i in range(0, dimension):
20         sum += (p1[i] - p2[i]) ** 2
21     return math.sqrt(sum)
22 
23 
24 class Kdtree:
25 
26     def __init__(self, data):
27         self.tree = self.buildChildTree(np.array(data))
28 
29     def buildChildTree(self, data):
30         if len(data) == 0 or data is None:
31             return None
32         dimension = data.ndim
33         if data.size == dimension:  # data.shape[1] 对一维情况会报错,出此下策。。。
34             return KdNode(data[0,], None, None)
35         vars = []
36         for i in range(dimension):
37             vars.append(np.var(data[:,i]))
38         max_dimension = vars.index(max(vars))
39         data_sorted = data[np.argsort(data[:,max_dimension]),:]
40         mid_i = data_sorted.shape[0] // 2
41         n = KdNode(None, None, None)
42         n.left = self.buildChildTree(data_sorted[:mid_i,])
43         n.right = self.buildChildTree(data_sorted[mid_i+1:,])
44         n.data = data_sorted[mid_i,]
45         return n
46 
47     def findNearestPoint(self, point):
48         cur = self.tree
49         nearest = self.tree.data
50         search_path = []
51         while 1:
52             search_path.append(nearest)
53             root = cur
54             if cur.left is None and cur.right is None:
55                 break
56             if root.left:
57                 if distance(root.left.data, point) < distance(nearest, point):
58                     cur = root.left
59                     nearest = root.left.data
60                     continue
61             elif root.right:
62                 if distance(root.right.data, point) < distance(nearest, point):
63                     cur = root.right
64                     nearest = root.right.data
65                     continue
66             break
67         return nearest, search_path
68 
69 
70 if __name__ == __main__:
71     kd = Kdtree([(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)])
72     n,p = kd.findNearestPoint(np.array([2, 4.5]))
73     print(nearest point: , n,    search path: ,p)   

 

原文:https://www.cnblogs.com/Asp1rant/p/15259429.html

文章分类
代码人生
文章标签
版权声明:本站是系统测试站点,无实际运营。本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 XXXXXXo@163.com 举报,一经查实,本站将立刻删除。
相关推荐