K近邻算法

Algorithm: KNN

Input: 训练数据集

$$T=\{(x_1, y_1), (x_2, y_2), \cdots, (x_N, y_N)\}$$

其中,$x_i \in \mathcal{X} \subseteq \mathbf{R}^n$ 为实例的特征向量,$y_i \in \mathcal{Y} = \{c_1, c_2, \cdots, c_K\}$ 为实例的类别,$i=1,2,\cdots,N$。

Output: 实例 $x$ 所属的类 $y$

  1. 根据给定的距离度量,在训练集 $T$ 中找出与 $x$ 最近邻的 $k$ 个点,涵盖这 $k$ 个点的 $x$ 的邻域记作 $N_k(x)$;
  2. 在 $N_k(x)$ 中根据分类决策规则(如多数表决)决定 $x$ 的类别 $y$:

$$y=\arg\max_{c_j} \sum_{x_i \in N_k(x)} I(y_i=c_j),\quad i=1,2,\cdots,N;\; j=1,2,\cdots,K$$

其中 $I$ 为指示函数:当 $y_i=c_j$ 时 $I=1$,否则 $I=0$。

KNN算法的实现

不依靠任何现有框架的KNN实现代码由DeepSeek R1生成。

算法实现的前期工作:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# %%
import numpy as np
import heapq
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Load MNIST from OpenML; dividing by 255 rescales the pixel values into [0, 1].
mnist = fetch_openml("mnist_784")
X, y = mnist["data"] / 255.0, mnist["target"].astype(int)

# The first 60000 rows form the training split, the remaining rows the test split.
X_train, y_train = X[:60000], y[:60000]
X_test, y_test = X[60000:], y[60000:]

# Standardise every feature with statistics estimated on the training split only,
# then apply the same transform to the test split.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# PCA keeping enough components to explain 95% of the variance, fitted on the
# standardised training data and applied to both splits.
# NOTE: the KNN experiment cells in this file slice X_train / X_test directly,
# so these PCA features are not the inputs those runs use.
pca = PCA(n_components=0.95)
X_train_pca = pca.fit_transform(X_train_scaled)
X_test_pca = pca.transform(X_test_scaled)

KNN的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def knn_linear_scan(X_train, y_train, X_test, k=3, distance_fn='euclidean'):
    """
    Classify each test sample by majority vote among its k nearest training
    samples, found by a brute-force (linear) scan over the whole training set.

    Parameters:
        X_train (array-like): training data, shape (n_samples, n_features).
            A numpy array or a pandas DataFrame.
        y_train (array-like): training labels, shape (n_samples,).
        X_test (array-like): test data, shape (m_samples, n_features).
        k (int): number of nearest neighbours, default 3. Values larger than
            n_samples are clamped to n_samples.
        distance_fn (str): distance function, 'euclidean' or 'manhattan'.

    Returns:
        numpy.ndarray: predicted labels, shape (m_samples,). A tied vote is
        resolved in favour of the smallest label.

    Raises:
        ValueError: unsupported distance_fn, k < 1, or an empty training set.
    """
    if distance_fn == 'euclidean':
        dist_func = lambda a, b: np.sqrt(np.sum((a - b)**2, axis=1))
    elif distance_fn == 'manhattan':
        dist_func = lambda a, b: np.sum(np.abs(a - b), axis=1)
    else:
        raise ValueError("只支持'euclidean'或'manhattan'距离")

    if k < 1:
        raise ValueError("k must be a positive integer")

    # Coerce to ndarrays: with a pandas Series, y_train[nearest] would index by
    # label instead of by position.
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    n_train = len(X_train)
    if n_train == 0:
        raise ValueError("X_train must contain at least one sample")
    # Never ask for more neighbours than there are training samples.
    k = min(k, n_train)

    predictions = []
    for x in np.asarray(X_test):
        distances = dist_func(X_train, x)
        # kth is a 0-based position: kth=k-1 places the k smallest distances in
        # the first k slots (kth=k would be out of range when k == n_train).
        nearest = np.argpartition(distances, k - 1)[:k]
        # np.unique returns the labels sorted, so argmax picks the smallest label
        # on a tie (like np.bincount) while also supporting negative or
        # non-integer labels.
        values, counts = np.unique(y_train[nearest], return_counts=True)
        predictions.append(values[np.argmax(counts)])
    return np.array(predictions)

# ------------------------- KD树实现 -------------------------
class KDNode:
    """
    One node of a KD tree: a stored training sample and the axis it splits on.

    Attributes:
        point: the training sample held by this node.
        label: class label associated with ``point``.
        axis (int): index of the coordinate used to split at this node.
        left, right: child subtrees (KDNode instances, or None when absent).
    """

    def __init__(self, point, label, axis, left=None, right=None):
        # Payload carried by the node.
        self.point, self.label = point, label
        # Splitting coordinate.
        self.axis = axis
        # Children; None marks a missing subtree.
        self.left, self.right = left, right

def build_kdtree(X, y, depth=0):
    """
    Recursively build a KD tree over the rows of X.

    At each level the split axis cycles through the features
    (depth modulo n_features). The rows are sorted along that axis, the row at
    index len(X) // 2 becomes the node, the rows before it form the left
    subtree and the rows after it form the right subtree.

    Parameters:
        X (numpy.ndarray): data, shape (n_samples, n_features).
        y (numpy.ndarray): labels aligned with the rows of X.
        depth (int): depth of the node being built (0 for the root).

    Returns:
        KDNode or None: root of the (sub)tree, or None when X is empty.
    """
    n_points = len(X)
    if n_points == 0:
        return None

    split_axis = depth % X.shape[1]
    order = np.argsort(X[:, split_axis])
    pts, lbls = X[order], y[order]
    median = n_points // 2

    left_child = build_kdtree(pts[:median], lbls[:median], depth + 1)
    right_child = build_kdtree(pts[median + 1:], lbls[median + 1:], depth + 1)
    return KDNode(pts[median], lbls[median], split_axis, left_child, right_child)

def kd_tree_search(root, query, k):
    """
    Find the k nearest neighbours (Euclidean distance) of a query in a KD tree.

    The subtree on the query's side of each splitting plane is searched first.
    The other subtree is visited only while fewer than k candidates are known,
    or when the plane is closer to the query than the current k-th best distance.

    Parameters:
        root (KDNode): root of the KD tree (may be None for an empty tree).
        query (array-like): query point, shape (n_features,).
        k (int): number of nearest neighbours to return.

    Returns:
        list: (label, distance) tuples sorted by ascending distance. Holds
        min(k, number of nodes) entries; empty when k <= 0 or root is None.
    """
    if k <= 0:
        return []
    query = np.asarray(query)

    # Max-heap of the best k candidates, emulated on heapq's min-heap by storing
    # (-distance, visit_order, node). visit_order is unique, so equal distances
    # never fall through to comparing node objects, which define no ordering
    # and would raise TypeError.
    heap = []
    visit_order = 0

    def recursive_search(node):
        nonlocal visit_order
        if node is None:
            return

        distance = np.sqrt(np.sum((node.point - query) ** 2))
        entry = (-distance, visit_order, node)
        visit_order += 1
        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif distance < -heap[0][0]:
            # Closer than the current worst candidate: replace it.
            heapq.heappushpop(heap, entry)

        axis = node.axis
        # Distance from the query to this node's splitting plane: a lower bound
        # on the distance to any point in the far subtree.
        plane_gap = abs(node.point[axis] - query[axis])
        if query[axis] < node.point[axis]:
            near, far = node.left, node.right
        else:
            near, far = node.right, node.left
        recursive_search(near)
        if len(heap) < k or plane_gap < -heap[0][0]:
            recursive_search(far)

    recursive_search(root)
    neighbours = [(node.label, -neg_dist) for neg_dist, _, node in heap]
    neighbours.sort(key=lambda item: item[1])
    return neighbours

def knn_kd_tree(X_train, y_train, X_test, k=3):
    """
    Classify each test sample by majority vote among its k nearest training
    samples, found with a KD-tree search (Euclidean distance).

    Parameters:
        X_train (array-like): training data, shape (n_samples, n_features).
            A numpy array or a pandas DataFrame.
        y_train (array-like): training labels, shape (n_samples,).
        X_test (array-like): test data, shape (m_samples, n_features).
        k (int): number of nearest neighbours, default 3.

    Returns:
        numpy.ndarray: predicted labels, shape (m_samples,). A tied vote is
        resolved in favour of the smallest label. A query for which the search
        returns no neighbours (e.g. an empty training set) is predicted as 0.
    """
    # np.asarray accepts both pandas objects and plain ndarrays; the previous
    # `.values` access only worked for pandas inputs.
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    tree = build_kdtree(X_train, y_train)

    predictions = []
    for x in np.asarray(X_test):
        labels = [label for label, _ in kd_tree_search(tree, x, k)]
        if not labels:
            # Nothing to vote on: keep the historical fallback prediction of 0.
            predictions.append(0)
            continue
        # np.unique returns the labels sorted, so argmax breaks ties toward the
        # smallest label (like np.bincount) but also handles non-integer labels.
        values, counts = np.unique(labels, return_counts=True)
        predictions.append(values[np.argmax(counts)])
    return np.array(predictions)

在 MNIST 数据集上运行上述实现,得到以下结果:

1
2
3
4
5
6
7
8
# ------------------------- Usage example -------------------------
# Use the first `sample` training images (currently all 60000) and the first
# `test` test images; lower `sample` to make the demo run faster.
sample = 60000
test = 100
X_sample = X_train.iloc[:sample]
y_sample = y_train.iloc[:sample]
X_test_sample = X_test.iloc[:test]
y_test_sample = y_test.iloc[:test]
1
2
3
# Linear-scan test. Accuracy is measured against y_test_sample, the labels of
# exactly the test rows that were classified, so it stays correct if `test` changes.
linear_pred = knn_linear_scan(X_sample.values, y_sample.values, X_test_sample.values, k=5)
print("准确率:", np.mean(linear_pred == y_test_sample.values))

运行时间:35.6s

准确率: 0.99

1
2
3
# KD-tree test. Accuracy is measured against y_test_sample, the labels of
# exactly the test rows that were classified, so it stays correct if `test` changes.
kd_pred = knn_kd_tree(X_sample, y_sample, X_test_sample, k=5)
print("准确率:", np.mean(kd_pred == y_test_sample.values))

运行时间:1m16.5s

准确率: 0.99

1
2
3
4
5
6
7
from sklearn.neighbors import KNeighborsClassifier

# Reference run: scikit-learn's brute-force neighbour search with k=5.
# fit() returns the estimator itself, so construction and fitting are chained.
knn_ls = KNeighborsClassifier(algorithm='brute', n_neighbors=5).fit(X_sample, y_sample)
y_pred = knn_ls.predict(X_test_sample)
print("预测结果:", y_pred)
# score() reports the mean accuracy on the given test samples.
print("准确率:", knn_ls.score(X_test_sample, y_test_sample))

运行时间:1.2s

预测结果: [7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4 9 6 6 5 4 0 7 4 0 1 3 1 3 0 7 2 7
1 2 1 1 7 4 2 3 5 1 2 4 4 6 3 5 5 6 0 4 1 9 5 7 8 9 3 7 4 6 4 3 0 7 0 2 9
1 7 3 2 9 7 7 6 2 7 8 4 7 3 6 1 3 6 9 3 1 4 1 7 6 9]

准确率: 0.99

1
2
3
4
5
# Reference run: scikit-learn's KD-tree neighbour search with k=5.
# fit() returns the estimator itself, so construction and fitting are chained.
knn_kd = KNeighborsClassifier(algorithm='kd_tree', n_neighbors=5).fit(X_sample, y_sample)
y_pred_kd = knn_kd.predict(X_test_sample)
print("kd树预测结果:", y_pred_kd)
# score() reports the mean accuracy on the given test samples.
print("kd树预测准确率:", knn_kd.score(X_test_sample, y_test_sample))

运行时间:25.1s

kd树预测结果: [7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4 9 6 6 5 4 0 7 4 0 1 3 1 3 0 7 2 7
1 2 1 1 7 4 2 3 5 1 2 4 4 6 3 5 5 6 0 4 1 9 5 7 8 9 3 7 4 6 4 3 0 7 0 2 9
1 7 3 2 9 7 7 6 2 7 8 4 7 3 6 1 3 6 9 3 1 4 1 7 6 9]

kd树预测准确率: 0.99

实验结果的分析

从李航所著的《机器学习方法》中我们知道,kd 树可以提高 k 近邻搜索的效率。但在本实验中,kd 树的效率反而不如线性扫描:自实现版本为 1m16.5s 对 35.6s,sklearn 版本为 25.1s 对 1.2s。

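MNIST 数据集的特征维度为 784。即使经过 PCA 降维(保留 95% 的方差),维度仍高达 331,依旧会产生“维度灾难(curse of dimensionality)”。需要注意,上面的实验直接对未降维的 `X_train`/`X_test` 取样(`X_sample = X_train.iloc[:sample]`),实际使用的是 784 维的原始特征。

kd 树每一层只按一个坐标轴划分空间。在高维空间中,查询点到分割超平面的距离(只差一个坐标)通常远小于它到当前第 $k$ 个近邻的距离,因此剪枝条件几乎总是不成立,搜索会退化为访问绝大多数节点。再加上建树和递归遍历的额外开销,kd 树自然比一次性向量化计算全部距离的线性扫描更慢。经验上,只有当样本数 $N$ 远大于 $2^n$($n$ 为维度)时,kd 树才能有效剪枝。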