“k最近邻算法(k-Nearest Neighbor,kNN)分类算法是机器学习算法中最简单的算法之一,所谓k最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。
kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。”——kNN算法_百科
一、kNN算法基础
kNN算法计算步骤:
- 获取已知类别的训练数据,将数据的特征进行标准化处理,以排除取值大的特征和分类变量对最终距离计算的干扰,常用的标准化方法如下:
极差标准化:
$$X_{new}=\frac{X-min(X)}{max(X)-min(X)} $$
中心标准化:
$$X_{new}=\frac{X-\mu}{\sigma}$$
生成哑变量:
$$male=
\begin{cases}
1\ \ if\ X=male\
0\ \ otherwise\
\end{cases} $$ - 计算预测数据与各个训练数据之间的距离,常用的距离有欧式距离和Manhattan距离(又叫Block距离),连续变量多的时候用欧式距离,分类变量多的时候用曼哈顿距离,公式如下:
欧式距离:
$$d(x,y)=\sqrt{\sum_{i=0}^N(x_i-y_i)^2}$$
曼哈顿距离:
$$d(x,y)=\sum_{i=0}^N|x_i-y_i|$$ - 按照距离的大小按顺序进行排序;
- 选取距离最小的k个点;
- 确定选取的k个点所在各类别的频率;
- 选择选取的k个点中频率最高的类别作为预测数据的预测类别。
二、kNN算法R语言实现
获取训练集,使用的训练集数据是CDA数据挖掘课程的一个数据集,数据集记录了收入、学历、受欢迎程度、资产水平及相亲的结果。
orgData<-read.csv("date_data2.csv")
将目标变量和自变量分开存储,为之后自变量标准化做准备。
y<-orgData[,c("Dated")]
x<-orgData[,c(1,2,3,4)]
检查缺失值,经检查,没有缺失值问题
summary(x)
写个极差标准化的函数
normalize <- function(x) {
return((x - min(x)) / (max(x) - min(x)))
}
对自变量进行极差标准化
x<-as.data.frame(lapply(x, normalize))
合并自变量和因变量
data<-cbind(y,x)
将目标变量转变为因子格式
data$y<-as.factor(data$y)
构建训练集和测试集,所用的knn模型需要把自变量因变量分开输入
select<-sample(1:nrow(data),length(data$y)*0.7)
train=data[select,-1]#选取70%做训练集
test=data[-select,-1]#另外30%做测试集
train.y=data[select,1]#训练集目标分类
test.y=data[-select,1]#测试集目标分类
使用kNN算法预测测试集数据,这里默认使用的是欧式距离,trian是训练集,test是测试集,cl是训练集实际分类,k取周围k个最临近的样本,prob=TRUE打印选中分类的概率,use.all=TRUE使用所有小于kth距离的样本用来预测。
y_hat<-knn(train = train,test = test,cl=train.y,k=10,prob=TRUE,use.all=TRUE)
测试不同k取值下的模型效果,并打印出来不同取值下的accuracy、Recall和Precision
ROC<-data.frame()
for (i in seq(from =1,to =40,by =1)){
y_hat<-knn(train = train,test = test,cl=train.y,k=i)
require(caret)
con=confusionMatrix(y_hat,test.y,positive='1')
accuracy.knn<-con$overall[c('Accuracy')] #准确率
recall.knn<-con$byClass[c('Sensitivity')]#召回率
precision.knn<-con$byClass[c('Pos Pred Value')]#精确率
out<-data.frame(i,accuracy.knn,recall.knn,precision.knn)
ROC<-rbind(ROC,out)
}
names(ROC)<-c("n","accuracy","Recall","Precision")
View(ROC)
结果如下:
画出不同k取值下的模型效果
plot(ROC$Recall~ROC$n,type='l',ylim=c(0.7,1))
lines(ROC$n,ROC$Precision,col=c("red"))
legend("bottomright", legend=c("Recall", "Precision"), col=c("black", "red"), lwd=2)