Условие
Реализуйте классификатор k-Nearest Neighbors без scikit-learn. На входе обучающая выборка X_train, y_train и k; на запросе x возвращайте предсказанный класс по большинству среди k ближайших соседей в евклидовой метрике.
Решение
Подход
- Для каждой
x_queryпосчитать расстояния до всехX_train. - Взять индексы k наименьших расстояний (
np.argpartition— быстрее, чем полная сортировка). - Среди соответствующих
y_trainпосчитать частоты классов; вернуть мажоритарный.
Реализация
from collections import Counter
import numpy as np
class KNN:
def __init__(self, k: int = 5):
self.k = k
def fit(self, X, y):
self.X_ = np.asarray(X, dtype=float)
self.y_ = np.asarray(y)
return self
def _predict_one(self, x):
# Евклидовы квадраты — sqrt не нужен для ранжирования
d2 = np.sum((self.X_ - x) ** 2, axis=1)
idx = np.argpartition(d2, self.k)[: self.k]
votes = Counter(self.y_[idx])
return votes.most_common(1)[0][0]
def predict(self, X):
X = np.asarray(X, dtype=float)
return np.array([self._predict_one(x) for x in X])Векторная версия предсказания пачки
def predict_batch(self, X):
X = np.asarray(X, dtype=float)
# (n_query, n_train)
d2 = np.sum((X[:, None, :] - self.X_[None, :, :]) ** 2, axis=2)
idx = np.argpartition(d2, self.k, axis=1)[:, : self.k]
neigh_labels = self.y_[idx]
# majority vote по строке
return np.array([Counter(row).most_common(1)[0][0] for row in neigh_labels])Подводные камни
- Масштаб признаков. kNN считает расстояния в исходном пространстве; признак «доход в долларах» доминирует над «возрастом». Стандартизация перед обучением обязательна.
sqrtне нужен для определения ближайших — экономьте время.- Ничьи при голосовании. Чётное
kили ничья голосов решается tie-break: либо снижаемk, либо смотрим на сумму расстояний. - Сложность predict: O(n_train × n_features) на каждый запрос — для миллионных датасетов нужен KD-tree / Ball-tree / Faiss / Annoy.
- Категориальные признаки. Евклид не работает; для них — Hamming / Gower distance или предварительное OHE.
Эталонный ответ
Расстояния через ((X - x)**2).sum(axis=1), индексы k наименьших через argpartition, мажоритарное голосование через Counter. Без sqrt, обязательно стандартизировать признаки перед обучением.