原來機器學習那么簡單—KNN回歸
一、什么是K近鄰回歸?
K近鄰回歸(K-Nearest Neighbors Regression,簡稱KNN回歸)是一種簡單直觀的機器學習算法。KNN回歸通過尋找樣本空間中與目標點最接近的K個鄰居,利用這些鄰居的平均值或加權平均值來預測目標點的值。KNN回歸屬于非參數模型,因為它不對數據的分布做出假設,也不需要訓練過程。
二、K近鄰回歸的原理
KNN回歸的核心思想非常直觀,即“相似的樣本具有相似的輸出”。具體步驟如下:
- 計算距離:對于待預測的樣本點,計算其與訓練集中每一個樣本點之間的距離。常用的距離度量包括歐氏距離(Euclidean Distance)、曼哈頓距離(Manhattan Distance)等。歐氏距離的計算公式為:
- 選擇K個鄰居:根據計算得到的距離,選擇距離待預測樣本點最近的K個鄰居。
- 計算預測值:根據選中的K個鄰居的輸出值,計算待預測樣本點的輸出值。常用的方法包括簡單平均和加權平均。
如果是簡單平均,則預測值為K個鄰居的輸出值的算術平均:
三、K近鄰回歸的優缺點
優點:
- 簡單直觀:算法思想簡單,容易理解和實現。
- 無模型假設:KNN回歸不對數據的分布做任何假設,適用于各種數據分布。
- 高靈活性:由于無需訓練過程,KNN回歸可以處理在線學習問題,也可以隨時加入新的數據。
缺點:
- 計算復雜度高:對于大規模數據集,計算每個樣本點的距離代價較高,影響預測效率。
- 維度災難:隨著特征維數的增加,樣本之間的距離變得越來越難以區分,導致預測效果下降。
- 對異常值敏感:KNN回歸直接依賴于鄰居的輸出值,如果鄰居中存在異常值,可能會嚴重影響預測結果。
四、案例分析
在這一部分,我們還是使用加州房價數據集來演示如何應用K近鄰回歸算法進行預測。加州房價數據集包含了加州的街區信息,每個街區有多項特征,包括人口、收入、房屋年齡等。目標是根據這些特征預測該街區的房屋中位數價格。
- 數據加載與預處理:
- 加載加州房價數據集并進行標準化處理,確保所有特征都在相同的尺度上。
- 將數據集劃分為訓練集和測試集,比例為8:2。
- 模型訓練:
- 使用
KNeighborsRegressor
創建一個K近鄰回歸模型,選擇K=5
,即考慮最近的5個鄰居。 - 用訓練集的數據來訓練模型。
- 模型預測:
- 利用訓練好的模型對測試集進行預測,并計算均方誤差(MSE)作為模型性能的評估指標。
代碼實現:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error
# 加載加州房價數據集
california = fetch_california_housing()
X = california.data
y = california.target
# 數據標準化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 劃分訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
# 創建K近鄰回歸模型并訓練
knn = KNeighborsRegressor(n_neighbors=5)
knn.fit(X_train, y_train)
# 預測測試集
y_pred = knn.predict(X_test)
# 計算均方誤差
mse = mean_squared_error(y_test, y_pred)
print(f"測試集的均方誤差: {mse:.2f}")
# 可視化結果
plt.figure(figsize=(10, 6))
plt.scatter(y_test, y_pred, edgecolor='k', alpha=0.7)
plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], 'r--', lw=3)
plt.xlabel("真實房價")
plt.ylabel("預測房價")
plt.title("K近鄰回歸預測結果")
plt.show()
結果分析:
測試集的均方誤差: 0.43。繪制模型預測的房價與真實房價之間的關系圖如下:
五、總結
K近鄰回歸是一種簡單且易于理解的回歸算法,適合用于小規模數據集或需要在線更新模型的場景。然而,在使用KNN回歸時,需要考慮數據的維數和計算復雜度,并對異常值進行處理,以確保模型的預測效果。
本文轉載自寶寶數模AI,作者: BBSM ????
贊
收藏
回復
分享
微博
QQ
微信
舉報

回復
相關推薦