KNN: The K-Nearest Neighbors Algorithm

魔法师LQ

The K-nearest neighbors (KNN) algorithm classifies a sample by measuring the distances between its feature values and those of samples whose classes are already known.
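The distance used here is the Euclidean distance. As a minimal sketch (the point values are arbitrary illustration numbers), NumPy can compute it from a query point to every row of a dataset at once:

```python
import numpy as np

# Euclidean distance from one query point to each row of a dataset
inX = np.array([0.0, 0.2])
dataSet = np.array([[0.0, 0.0], [1.0, 1.0]])

# (inX - dataSet) broadcasts the query point against every row;
# summing the squared differences along axis=1 gives one distance per row
distances = np.sqrt(((inX - dataSet) ** 2).sum(axis=1))
print(distances)  # distance to [0, 0] is 0.2; to [1, 1] it is sqrt(1 + 0.64)
```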

import numpy as np

# Collect the data: nothing to do here
# Prepare the data: build the arrays used for the distance computation
def createDataset():
    group = np.array([[0.0, 0.0], [0.0, 0.1], [1.0, 1.0], [1.0, 1.1]])
    labels = np.array(['A', 'A', 'B', 'B'])
    return group, labels
group, labels = createDataset()
print(group)
print(labels)
[[0.  0. ]
 [0.  0.1]
 [1.  1. ]
 [1.  1.1]]
['A' 'A' 'B' 'B']

Implementing the K-nearest neighbors algorithm:

  1. Compute the distance between the query point and every point in the labeled dataset;
  2. Sort the points by increasing distance;
  3. Select the K points closest to the query point;
  4. Count how often each class appears among those K points;
  5. Return the most frequent class as the predicted class of the query point.
def classify0(inX, dataSet, labels, k):
    # Steps 1-2: distances to every known point, sorted ascending
    distances = np.sqrt(((inX - dataSet)**2).sum(axis=1))
    sortedDistIndicies = distances.argsort()

    # Steps 3-4: tally the labels of the k nearest points
    classCount = {}
    for i in range(k):
        voteLabel = labels[sortedDistIndicies[i]]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1

    # Step 5: return the most frequent label
    sortedClassCount = sorted(classCount.items(), key=lambda s: s[1], reverse=True)
    return sortedClassCount[0][0]
classify0([0.0, 0.2], group, labels, 3)
'A'
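The dictionary-based vote in `classify0` can also be written with the standard library's `collections.Counter`. This is an equivalent sketch of the same majority vote, not the original code:

```python
import numpy as np
from collections import Counter

def classify_counter(inX, dataSet, labels, k):
    # same distance computation as classify0
    distances = np.sqrt(((np.asarray(inX) - dataSet) ** 2).sum(axis=1))
    # labels of the k nearest points, via fancy indexing on the sort order
    kNearest = labels[distances.argsort()[:k]]
    # most_common(1) returns [(label, count)] for the winning label
    return Counter(kNearest).most_common(1)[0][0]
```

On the toy dataset above, `classify_counter([0.0, 0.2], group, labels, 3)` returns 'A', matching `classify0`.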

Example: using the K-nearest neighbors algorithm on a dating site

  1. Collect the data: a downloaded text file.
  2. Prepare the data: parse the text file with Python.
  3. Analyze the data: draw 2-D scatter plots with Matplotlib.
  4. Train the algorithm: this step does not apply to K-nearest neighbors.
  5. Test the algorithm: hold out part of the provided data as a test set. Test samples differ from the rest only in that their true classes are known; whenever the predicted class disagrees with the actual class, count one error.
  6. Build a simple command-line program that takes a few feature values from the user and predicts whether the other person is someone they would like.

Prepare the data: parsing data from a text file

Assume the records are stored in a text file, one sample per line, with three features per sample:

  • miles traveled per year
  • percentage of daily time spent playing video games
  • liters of drinks consumed per week
def file2matrix(filename):
    returnMat = []
    labelVec = []
    with open(filename, 'r') as f:
        for line in f.readlines():
            listFromLine = line.strip().split('\t')
            returnMat.append(listFromLine[:-1])
            labelVec.append(int(listFromLine[-1]))
    return np.array(returnMat, dtype=np.float32), labelVec
returnMat, labelVec = file2matrix('datingTestSet2.txt')
print(returnMat[:3])
[[4.092000e+04 8.326976e+00 9.539520e-01]
 [1.448800e+04 7.153469e+00 1.673904e+00]
 [2.605200e+04 1.441871e+00 8.051240e-01]]
print(labelVec[:5])
[3, 2, 1, 1, 1]

NumPy's array type is not interchangeable with Python's built-in list type, so take care not to pass the wrong one to these functions.
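For instance, arithmetic that is elementwise on a NumPy array means something entirely different on a Python list. A quick illustration:

```python
import numpy as np

# multiplying a list repeats it; multiplying an array scales each element
print([1.0, 2.0] * 2)            # [1.0, 2.0, 1.0, 2.0]
print(np.array([1.0, 2.0]) * 2)  # [2. 4.]

# subtraction is elementwise for arrays; for lists it raises a TypeError
print(np.array([3.0, 4.0]) - np.array([1.0, 1.0]))  # [2. 3.]
```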

Analyze the data: creating scatter plots with Matplotlib

import matplotlib.pyplot as plt
%matplotlib notebook

fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(returnMat[:, 1], returnMat[:, 2], 15.0*np.array(labelVec), 15.0*np.array(labelVec))
ax.set_xlabel('Percentage of Time Spent Playing Video Game')
ax.set_ylabel('Liter of Drinks')
plt.show()

fig2 = plt.figure()
idx1 = np.where(np.array(labelVec)==1)
p1 = plt.scatter(returnMat[idx1, 0], returnMat[idx1, 1], marker='x', color='m', label='dislike', s=15)
idx2 = np.where(np.array(labelVec)==2)
p2 = plt.scatter(returnMat[idx2, 0], returnMat[idx2, 1], marker='+', color='c', label='not bad', s=20)
idx3 = np.where(np.array(labelVec)==3)
p3 = plt.scatter(returnMat[idx3, 0], returnMat[idx3, 1], marker='o', color='r', label='charming', s=30)


plt.xlabel('Mileage')
plt.ylabel('Percentage of Time Spent Playing Video Game')
plt.legend(loc='upper right')
plt.show()

Prepare the data: normalizing the values

Rescale every feature to a common range, such as 0 to 1 or -1 to 1:

newValue = (oldValue-min)/(max-min)
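As a quick worked example of the formula (with made-up numbers), a feature column holding 0, 5, and 10 maps to 0, 0.5, and 1:

```python
import numpy as np

col = np.array([0.0, 5.0, 10.0])  # made-up feature values
normalized = (col - col.min()) / (col.max() - col.min())
print(normalized)  # [0.  0.5 1. ]
```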
def autoNorm(dataSet):
    minValues = dataSet.min(0)
    maxValues = dataSet.max(0)
    ranges = maxValues - minValues
    normDataSet = (dataSet - minValues) / ranges
    return normDataSet, ranges, minValues
autoNorm(returnMat)
(array([[0.44832537, 0.39805138, 0.5623336 ],
        [0.1587326 , 0.34195465, 0.9872441 ],
        [0.28542942, 0.06892523, 0.4744963 ],
        ...,
        [0.29115948, 0.5091029 , 0.51079494],
        [0.527111  , 0.4366545 , 0.42900482],
        [0.47940794, 0.37680906, 0.7857181 ]], dtype=float32),
 array([9.127300e+04, 2.091935e+01, 1.694361e+00], dtype=float32),
 array([0.      , 0.      , 0.001156], dtype=float32))
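One caveat: `autoNorm` divides by `ranges`, so a feature that is constant across all samples (max equals min) would divide by zero. A defensive variant — a sketch, not part of the original code — replaces zero ranges with 1 so constant features simply normalize to 0:

```python
import numpy as np

def autoNormSafe(dataSet):
    minValues = dataSet.min(0)
    ranges = dataSet.max(0) - minValues
    ranges[ranges == 0] = 1.0  # avoid division by zero for constant features
    normDataSet = (dataSet - minValues) / ranges
    return normDataSet, ranges, minValues
```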

Test the algorithm: validating the classifier with a complete test program

def datingClassTest():
    k = 3
    ratio = 0.10  # fraction of samples held out as the test set
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * ratio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifyResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], k)
        print('Predict: ', classifyResult, ' Real:', datingLabels[i])
        if classifyResult != datingLabels[i]:
            errorCount += 1.0
    print('Total error rate: %f' % (errorCount / float(numTestVecs)))
datingClassTest()
Predict:  3  Real: 3
Predict:  2  Real: 2
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  3  Real: 3
Predict:  3  Real: 3
Predict:  1  Real: 1
Predict:  3  Real: 3
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  2  Real: 2
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  2  Real: 2
Predict:  3  Real: 3
Predict:  2  Real: 2
Predict:  1  Real: 1
Predict:  3  Real: 2
Predict:  3  Real: 3
Predict:  2  Real: 2
Predict:  3  Real: 3
Predict:  2  Real: 2
Predict:  3  Real: 3
Predict:  2  Real: 2
Predict:  1  Real: 1
Predict:  3  Real: 3
Predict:  1  Real: 1
Predict:  3  Real: 3
Predict:  1  Real: 1
Predict:  2  Real: 2
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  2  Real: 2
Predict:  3  Real: 3
Predict:  3  Real: 3
Predict:  1  Real: 1
Predict:  2  Real: 2
Predict:  3  Real: 3
Predict:  3  Real: 3
Predict:  3  Real: 3
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  2  Real: 2
Predict:  2  Real: 2
Predict:  1  Real: 1
Predict:  3  Real: 3
Predict:  2  Real: 2
Predict:  2  Real: 2
Predict:  2  Real: 2
Predict:  2  Real: 2
Predict:  3  Real: 3
Predict:  1  Real: 1
Predict:  2  Real: 2
Predict:  1  Real: 1
Predict:  2  Real: 2
Predict:  2  Real: 2
Predict:  2  Real: 2
Predict:  2  Real: 2
Predict:  2  Real: 2
Predict:  3  Real: 3
Predict:  2  Real: 2
Predict:  3  Real: 3
Predict:  1  Real: 1
Predict:  2  Real: 2
Predict:  3  Real: 3
Predict:  2  Real: 2
Predict:  2  Real: 2
Predict:  3  Real: 1
Predict:  3  Real: 3
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  3  Real: 3
Predict:  3  Real: 3
Predict:  1  Real: 1
Predict:  2  Real: 2
Predict:  3  Real: 3
Predict:  3  Real: 1
Predict:  3  Real: 3
Predict:  1  Real: 1
Predict:  2  Real: 2
Predict:  2  Real: 2
Predict:  1  Real: 1
Predict:  1  Real: 1
Predict:  3  Real: 3
Predict:  2  Real: 3
Predict:  1  Real: 1
Predict:  2  Real: 2
Predict:  1  Real: 1
Predict:  3  Real: 3
Predict:  3  Real: 3
Predict:  2  Real: 2
Predict:  1  Real: 1
Predict:  3  Real: 1
Total error rate: 0.050000
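The 5% error rate above depends on the choice of k (and on the holdout ratio). A self-contained sketch on synthetic data — two well-separated Gaussian clusters standing in for the dating data — shows how one might compare several values of k under the same holdout scheme:

```python
import numpy as np

def classify0(inX, dataSet, labels, k):
    # same KNN vote as before, written with np.unique
    distances = np.sqrt(((inX - dataSet) ** 2).sum(axis=1))
    kNearest = labels[distances.argsort()[:k]]
    values, counts = np.unique(kNearest, return_counts=True)
    return values[counts.argmax()]

rng = np.random.default_rng(0)
# synthetic stand-in for the dating data: two well-separated clusters
data = np.vstack([rng.normal(0.0, 0.1, (50, 2)),
                  rng.normal(1.0, 0.1, (50, 2))])
labels = np.array([1] * 50 + [2] * 50)
shuffle = rng.permutation(len(data))
data, labels = data[shuffle], labels[shuffle]

numTest = 10  # hold out the first 10% as the test set
for k in (1, 3, 5):
    errors = sum(classify0(data[i], data[numTest:], labels[numTest:], k) != labels[i]
                 for i in range(numTest))
    print('k=%d error rate: %.2f' % (k, errors / numTest))
```

Because these clusters barely overlap, every k here should reach a near-zero error rate; on real data the comparison is more informative.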

Use the algorithm: building a complete, usable system

def classifyPerson():
    # map numeric labels to descriptions
    result_dict = {1: 'not at all', 2: 'in small doses', 3: 'in large doses'}
    # read the new candidate's key attributes
    inM = float(input('Please input Mileage per Year: '))
    inP = float(input('Please input Percentage of Time Spent Playing Computer Games: '))
    inD = float(input('Please input Liter of Drinks: '))
    inX = np.array([inM, inP, inD])
    # predict with KNN; the new input must be normalized with the same
    # ranges and minimums as the training data
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    normMat, ranges, minVals = autoNorm(datingDataMat)
    classPred = classify0((inX - minVals) / ranges, normMat, datingLabels, 3)
    print('You probably like this person: ', result_dict[classPred])
classifyPerson()
Please input Mileage per Year: 40920
Please input Percentage of Time Spent Playing Computer Games: 8.326
Please input Liter of Drinks: 0.9553
You probably like this person:  in large doses