KD-Tree和KNN算法原理与实现(Python)

KD-Tree和KNN算法原理与实现(Python)

Mr.GGLS 744 2022-05-13

什么是KD-Tree

kd-tree

下面是来自[Geeksforgeeks](K Dimensional Tree | Set 1 (Search and Insert) - GeeksforGeeks)的解释

A K-D Tree(also called as K-Dimensional Tree) is a binary search tree where data in each node is a K-Dimensional point in space. In short, it is a space partitioning(details below) data structure for organizing points in a K-Dimensional space.

KD-Tree也就是一个在K维空间的二叉搜索树,在建立树的过程中,以当前节点为中心形成一个超平面划分空间为两半,左右子树分别在左右子空间划分获得

Why KD-Tree

KD-Tree可用于在k维空间的高效查询,最近邻、范围查询等。例如在k近邻算法中,一般查找某个节点一定半径内的最近邻的暴力解法需要将所有节点到指定节点的距离都算一遍,这在小规模的数据集还好,在百万上千万的节点量下还是显得太慢,而KD-Tree在查询过程中每次进入子树都会排除掉另一个子空间,最好情况是只需要O(logn)O(logn)的复杂度即可完成查询

如何划分KD-Tree

  • 首先确定当前节点(n维)(x0,x1,..,xd,..,xn)(x_0,x_1,..,x_d,..,x_n)划分空间所用的维度d
  • 将当前节点的第k维坐标xdx_d与其他未划分节点对应维度坐标进行比较,小于当前节点的去左子空间,否则去右子空间
  • 最开始的根节点以第0维开始划分,之后依次加1
  • 如果树的深度大于节点的维度,则第k次划分(k>n)所用的维度是k mod n
  • 绕了请看下面例子

举个栗子(来自GeeksForGeeks)

考虑如下节点

(3, 6), (17, 15), (13, 15), (6, 12), (9, 1), (2, 7), (10, 19)

我们从左到右依次插入上述节点到kd-tree中,节点只有两维(x和y)

  1. 根节点是(3, 6),划分维度为x
  2. 插入(17, 15),在x上大于3,所以划分到右子空间(17>3),(17, 15)的划分维度为y
  3. 插入(13, 15),由于在下上大于3,划分到右子空间(13>3),再和(17, 15)在y轴上比较,划分到(17, 15)的右子空间(15>=15)
  4. 插入(6, 12),由于在x上大于3,划分到右子空间(6>3),再和(17, 15)在y轴上比较,划分到(17, 15)的左子空间(12<15),(6, 12)的划分维度为x(树的深度大于节点维度 2 mod 2 = 0)
  5. 插入(9, 1),由于在x上大于3,划分到右子空间(9>3),再和(17, 15)在y轴上比较,划分到(17, 15)的左子空间(1<15),接着和(6, 12)在x轴上进行比较,被划分到右子空间(9>6),划分维度为y
  6. 依此类推
  7. 最终成树为
  8. image-20220501215818079
  9. 如果将树转化为二维坐标图,划分情况如下(每条虚线划分出两个子空间)
  10. image-20220501215932061

最近邻查询

看这老哥的1.3(懒得画图了)KD树详解及KD树最近邻算法_weixin_43312083的博客-CSDN博客_kd树最近邻搜索算法

KD-Tree和KNN算法的实现(Python)

Let’s build our own kd-tree !

国内博客大多都是C++实现(不知道是不是互相借鉴的。。。),而且就我搜索来看没有关于knn的实现,下面的代码是我参考了国外的一位教授发布到网上的教程和一位大佬博客里的knn python实现(找不到链接了😅)编写而成

我觉得应该或许大概可能已经很通俗易懂了🤔

import numpy as np


class KDTree:
    def __init__(self, P, d=0):
        n = len(P) # P为待建树的所有节点
        mid = n // 2 
        P.sort(key=lambda x: x[d]) # 给节点在第d维排个序
        self.point = P[mid] # 第mid个节点就是待划分的节点
        self.d = d # 当前节点划分的维度
        d = (d + 1) % len(P[0]) # 左右子节点的划分维度 d = d mod n
        self.left = self.right = None
        if mid > 0: # 左子空间还有节点
            self.left = KDTree(P[:mid], d)
        if n - mid - 1 > 0: # 右子空间还有节点
            self.right = KDTree(P[mid + 1:], d)

    def print_tree(self): # 前序遍历打印
        print(self.point)
        if self.left:
            print("left: ")
            self.left.print_tree()
        if self.right:
            print("right: ")
            self.right.print_tree()

    def knn(self, target, k=1): # target为需要查找k近邻的节点
        nearests = [] # 最近k个节点和最短距离的集合(最短距离,节点)
        for i in range(k):
            nearests.append([-1,None])
        def search(node):
            if node is None: # 空节点无需再找
                return
            diff_d = node.point[node.d] - target[node.d]
            if diff_d > 0: # target在node节点的划分维度上小于node,应该去左子空间
                search(node.left)
            else:
                search(node.right)
            sum = 0
            for i in range(len(node.point)):
                sum += (node.point[i]-target[i])**2
            cur_dis = sum**0.5 # 计算node到target的欧式距离
            # 下面这段代码插入knn节点,重点内容,请好好体会
            # nearests从左到右,到target的距离依次递增
            for i in range(len(nearests)):
                # 元素为空或者当前元素到target的最短距离大于node到target的最短距离
                if nearests[i][0] == -1 or nearests[i][0] > cur_dis:
                    nearests.insert(i,[cur_dis, node])
                    nearests.pop()# 保证nearests长度为k,去掉末尾的元素
                    break
            index = -1
            for i in range(len(nearests)): # 找到距target距离最远的元素,称为r吧
                if nearests[i][0] == -1:
                    index = i - 1
                    break
            # 下面这段代码也是重点
            # 这个if的意思是当以target为超球体的球心,r到target的最短距离为半径时
            # 如果该超球体与node用来划分的超平面有交集则可能在node的另一个我们在上面未查找的子空间中存在到target距离更小的节点
            if nearests[index][0] > abs(diff_d):
                if diff_d > 0: # 上面node去左子空间了,所以这里去右子空间
                    search(node.right)
                else:
                    search(node.left)
        search(self)
        return [[item[0],item[1].point]for item in nearests]


if __name__ == '__main__':
    points = [(2, 3), (4, 7), (5, 4), (7, 2), (8, 1), (9, 6)]
    # KDTree(points).print_tree()
    tree = KDTree(points)
    coord = (2.2, 3.2)
    nearest_nodes = tree.knn(coord, 3)
    print(nearest_nodes[-1][0])
    plt.scatter(coord[0], coord[1], s=(nearest_nodes[-1][0]*100)**2, alpha=0.5)
    X = [item[0] for item in points]
    Y = [item[1] for item in points]
    plt.scatter(X,Y, c='g')
    plt.scatter(coord[0],coord[1],c='r')
    plt.annotate(str(coord),xy=coord,xytext=(coord[0]+0.2,coord[1]+0.2))
    for item in nearest_nodes:
        node = item[1]
        plt.annotate(str(node),xy=node, xytext=(node[0]+0.3,node[1]-0.3))
    plt.title('knn-example')
    plt.grid()
    plt.show()

测试代码

查找在(2.2, 3.2)的3近邻点

输出结果

[[0.2828427124746193, (2, 3)], [2.912043955712207, (5, 4)], [4.2047592083257275, (4, 7)]]
4.2047592083257275

knn-example

图里的半径不是很准,意思意思就好


# python # ML