之前两篇随笔介绍了kd树的原理,并⽤python实现了kd树的构建和搜索,具体可以参考
kd树常与knn算法联系在⼀起,knn算法通常要搜索k近邻,⽽不仅仅是最近邻,下⾯的代码将利⽤kd树搜索⽬标点的k个近邻。 ⾸先还是创建⼀个类,⽤于保存结点的值,左右⼦树,以及⽤于划分左右⼦树的切分轴class decisionnode:
def __init__(self,value=None,col=None,rb=None,lb=None): self.value=value self.col=col self.rb=rb self.lb=lb
切分点为坐标轴上的中值,下⾯代码求得⼀个序列的中值def median(x): n=len(x) x=list(x)
x_order=sorted(x)
return x_order[n//2],x.index(x_order[n//2])
然后按照左⼦树⼤于切分点,右⼦树⼩于切分点的规则构造kd树,其中data是输⼊的数据#以j列的中值划分数据,左⼩右⼤,j=节点深度%列数 def buildtree(x,j=0): rb=[] lb=[]
m,n=x.shape
if m==0: return None
edge,row=median(x[:,j].copy()) for i in range(m): if x[i][j]>edge: rb.append(i) if x[i][j] return decisionnode(x[row,:],j,rightBranch,leftBranch) 接下来就是搜索树得到k近邻的过程,与搜索最近邻的过程⼤致相同,需要创建⼀个字典knears,⽤于存储k近邻的点以及与⽬标点的距离(欧⽒距离) 搜索的过程为: (1)第⼀步还是遍历树,找到⽬标点所属区域对应的叶节点 (2)从叶结点依次向上回退,按照寻找最近邻点的⽅法回退到⽗节点,并判断其另⼀个⼦节点对区域内是否可能存在k近邻点,具体的,在每个结点上进⾏以下操作: (a)如果字典中的成员个数不⾜k个,将该结点加⼊字典 (b)如果字典中的成员不少于k个,判断该结点与⽬标结点之间的距离是否不⼤于字典中各结点所对应距离的的最⼤值,如果不⼤于,便将其加⼊到字典中 (c)对于⽗节点来说,如果⽬标点与其切分轴之间的距离不⼤于字典中各结点所对应距离的的最⼤值,便需要访问该⽗节点的另⼀个⼦节点 (3)每当字典中增加新成员,就按距离值对字典进⾏降序排序,将得到的列表赋值给poinelist,pointlist[0][1]便是字典中各结点所对应距离的最⼤值 (4)当回退到根节点并完成对其操作时,pointlist中后k个结点就是⽬标点的k近邻 代码如下: #搜索树:输出⽬标点的近邻点def traveltree(node,aim): global pointlist #存储排序后的k近邻点和对应距离 if node==None: return col=node.col if aim[col]>node.value[col]: traveltree(node.rb,aim) if aim[col] knears.setdefault(tuple(node.value.tolist()),dis) pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True) if node.rb!=None or node.lb!=None: if abs(aim[node.col] - node.value[node.col]) < pointlist[0][1]: if aim[node.col] 1 import numpy as np 2 from numpy import array 3 class decisionnode: 4 def __init__(self,value=None,col=None,rb=None,lb=None): 5 self.value=value 6 self.col=col 7 self.rb=rb 8 self.lb=lb 9 10 #读取数据并将数据转换为矩阵形式 11 def readdata(filename): 12 data=open(filename).readlines()13 x=[] 14 for line in data: 15 line=line.strip().split('\')16 x_i=[] 17 for num in line:18 num=float(num)19 x_i.append(num)20 x.append(x_i)21 x=array(x)22 return x23 24 #求序列的中值 25 def median(x):26 n=len(x)27 x=list(x) 28 x_order=sorted(x) 29 return x_order[n//2],x.index(x_order[n//2])30 31 #以j列的中值划分数据,左⼩右⼤,j=节点深度%列数 32 def buildtree(x,j=0):33 rb=[]34 lb=[] 35 m,n=x.shape 36 if m==0: return None 37 edge,row=median(x[:,j].copy())38 for i in range(m):39 if x[i][j]>edge: 40 rb.append(i)41 if x[i][j] 47 return decisionnode(x[row,:],j,rightBranch,leftBranch)48 49 #搜索树:输出⽬标点的近邻点50 def traveltree(node,aim): 51 global pointlist #存储排序后的k近邻点和对应距离52 if node==None: return 53 col=node.col 54 if aim[col]>node.value[col]:55 traveltree(node.rb,aim)56 if aim[col] 63 knears.setdefault(tuple(node.value.tolist()),dis) 64 pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True) 65 if node.rb!=None or node.lb!=None: 66 if abs(aim[node.col] - node.value[node.col]) < pointlist[0][1]:67 if aim[node.col] 73 def dist(x1, x2): #欧式距离的计算 74 return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5 75 76 knears={} 77 k=int(input('请输⼊k的值'))78 if k<2: print('k不能是1')79 global pointlist80 pointlist=[] 81 file=input('请输⼊数据⽂件地址')82 data=readdata(file)83 tree=buildtree(data) 84 tmp=input('请输⼊⽬标点')85 tmp=tmp.split(',')86 aim=[] 87 for num in tmp:88 num=float(num)89 aim.append(num)90 aim=tuple(aim) 91 pointlist=traveltree(tree,aim)92 for point in pointlist[-k:]:93 print(point)kdtree 因篇幅问题不能全部显示,请点此查看更多更全内容