1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
def knn_linear_scan(X_train, y_train, X_test, k=3, distance_fn='euclidean'):
    """Predict labels for ``X_test`` by brute-force k-nearest-neighbour voting.

    Parameters:
        X_train (numpy.ndarray): training data, shape (n_samples, n_features).
        y_train (numpy.ndarray): training labels, shape (n_samples,). Must be
            non-negative integers because votes are tallied with ``np.bincount``.
        X_test (numpy.ndarray): query data, shape (m_samples, n_features).
        k (int): number of nearest neighbours, default 3. A value larger than
            ``n_samples`` uses every training sample instead of failing.
        distance_fn (str): distance metric, either 'euclidean' or 'manhattan'.

    Returns:
        numpy.ndarray: predicted labels, shape (m_samples,). Vote ties are
        resolved in favour of the smallest label (``np.argmax`` semantics).

    Raises:
        ValueError: if ``distance_fn`` is unsupported, ``k`` < 1, or
            ``X_train`` contains no samples.
    """
    if distance_fn == 'euclidean':
        dist_func = lambda a, b: np.sqrt(np.sum((a - b)**2, axis=1))
    elif distance_fn == 'manhattan':
        dist_func = lambda a, b: np.sum(np.abs(a - b), axis=1)
    else:
        raise ValueError("只支持'euclidean'或'manhattan'距离")

    if k < 1:
        raise ValueError("k must be a positive integer")
    n_train = len(X_train)
    if n_train == 0:
        raise ValueError("X_train must contain at least one sample")

    # np.argpartition requires 0 <= kth < n_samples. Partitioning around
    # index k places the k smallest distances in the first k slots; when
    # k >= n_samples we partition around the last index instead, so the
    # slice below simply returns every training sample.
    kth = min(k, n_train - 1)

    predictions = []
    for x in X_test:
        distances = dist_func(X_train, x)
        nearest = np.argpartition(distances, kth)[:k]
        counts = np.bincount(y_train[nearest])
        predictions.append(np.argmax(counts))
    return np.array(predictions)
class KDNode:
    """One node of a KD tree.

    Attributes:
        point: the sample stored at this node.
        label: the label associated with ``point``.
        axis: index of the feature dimension this node splits on.
        left: left child node, or None.
        right: right child node, or None.
    """

    def __init__(self, point, label, axis, left=None, right=None):
        self.point, self.label, self.axis = point, label, axis
        self.left, self.right = left, right
def build_kdtree(X, y, depth=0):
    """Recursively build a KD tree over the rows of ``X`` and labels ``y``.

    The split axis cycles through the feature columns with depth
    (``depth % n_features``). At each level the samples are sorted along
    that axis and the element at index ``len(X) // 2`` becomes the node.
    The samples sorted before it form the left subtree and those after it
    form the right subtree.

    Parameters:
        X (numpy.ndarray): training data, 2-D (n_samples, n_features).
        y (numpy.ndarray): labels aligned with the rows of ``X``.
        depth (int): depth of the node being built; 0 for the root.

    Returns:
        KDNode | None: root of the (sub)tree, or None when ``X`` is empty.
    """
    n_samples = len(X)
    if n_samples == 0:
        return None

    split_axis = depth % X.shape[1]
    order = np.argsort(X[:, split_axis])
    xs = X[order]
    ys = y[order]

    median = n_samples // 2
    next_depth = depth + 1
    left_child = build_kdtree(xs[:median], ys[:median], next_depth)
    right_child = build_kdtree(xs[median + 1:], ys[median + 1:], next_depth)

    return KDNode(
        point=xs[median],
        label=ys[median],
        axis=split_axis,
        left=left_child,
        right=right_child,
    )
def kd_tree_search(root, query, k):
    """Find the ``k`` points of a KD tree closest to ``query`` (Euclidean).

    Parameters:
        root (KDNode): root of the KD tree, or None for an empty tree.
        query (numpy.ndarray): query point with the same feature count as
            the points stored in the tree.
        k (int): number of nearest neighbours to return.

    Returns:
        list: up to ``k`` ``(label, distance)`` tuples, sorted by ascending
        distance (ties in visit order). Fewer than ``k`` tuples are returned
        when the tree holds fewer points; an empty list is returned when
        ``k < 1`` or ``root`` is None.
    """
    import itertools

    if k < 1:
        return []

    # Max-heap of the best candidates, implemented as a min-heap on the
    # negated distance, so heap[0] is always the current worst candidate.
    # The strictly increasing sequence number breaks ties between equal
    # distances, so heapq never has to compare node objects (KDNode has no
    # ordering; comparing two nodes would raise TypeError).
    heap = []
    sequence = itertools.count()

    def recursive_search(node):
        if node is None:
            return

        distance = np.sqrt(np.sum((node.point - query) ** 2))
        entry = (-distance, next(sequence), node)
        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif distance < -heap[0][0]:
            # Strictly closer than the current worst candidate: replace it.
            heapq.heapreplace(heap, entry)

        diff = query[node.axis] - node.point[node.axis]
        # Visit the side of the splitting plane that contains the query first.
        if diff < 0:
            near, far = node.left, node.right
        else:
            near, far = node.right, node.left
        recursive_search(near)
        # The far side can only contain a closer point if the splitting plane
        # is nearer than the current k-th best distance (or the heap is not
        # yet full).
        if len(heap) < k or abs(diff) < -heap[0][0]:
            recursive_search(far)

    recursive_search(root)
    ordered = sorted(heap, key=lambda item: (-item[0], item[1]))
    return [(node.label, -neg_dist) for neg_dist, _, node in ordered]
def knn_kd_tree(X_train, y_train, X_test, k=3):
    """Predict labels for ``X_test`` by k-nearest-neighbour voting over a KD tree.

    Takes the same data parameters as ``knn_linear_scan`` (without
    ``distance_fn``): neighbours are found with ``kd_tree_search``, which
    uses Euclidean distance.

    Parameters:
        X_train: training data, 2-D (n_samples, n_features). A
            ``numpy.ndarray`` or anything ``np.asarray`` converts to one
            (e.g. a pandas ``DataFrame``).
        y_train: training labels, shape (n_samples,). Must be non-negative
            integers because votes are tallied with ``np.bincount``.
        X_test: query data, 2-D (m_samples, n_features).
        k (int): number of nearest neighbours, default 3.

    Returns:
        numpy.ndarray: predicted labels, shape (m_samples,). Vote ties go to
        the smallest label. A query that receives no neighbours (e.g. an
        empty training set) is labelled 0.
    """
    # np.asarray accepts both numpy arrays and pandas objects. The previous
    # `.values` access only worked for pandas inputs, although the documented
    # contract (same parameters as knn_linear_scan) is numpy arrays.
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    X_test = np.asarray(X_test)

    tree = build_kdtree(X_train, y_train)
    predictions = []
    for x in X_test:
        labels = [label for label, _ in kd_tree_search(tree, x, k)]
        predictions.append(np.argmax(np.bincount(labels)) if labels else 0)
    return np.array(predictions)
|