计算机视觉
图像处理

K-近邻分类算法KNN

文章目录

K-近邻(K-Nearest Neighbors, KNN)是一种很好理解的分类算法,简单说来就是从训练样本中找出K个与其最相近的样本,然后看这K个样本中哪个类别的样本多,则待判定的值(或说抽样)就属于这个类别。

KNN算法的步骤

  • 计算已知类别数据集中每个点与当前点的距离;
  • 选取与当前点距离最小的K个点;
  • 统计前K个点中每个类别的样本出现的频率;
  • 返回前K个点出现频率最高的类别作为当前点的预测分类。

OpenCV中使用CvKNearest

OpenCV中实现CvKNearest类可以实现简单的KNN训练和预测。
  1. int main()
  2. {
  3.     float labels[10] = {0,0,0,0,0,1,1,1,1,1};
  4.     Mat labelsMat(10, 1, CV_32FC1, labels);
  5.     cout<<labelsMat<<endl;
  6.     float trainingData[10][2];
  7.     srand(time(0));
  8.     for(int i=0;i<5;i++){
  9.         trainingData[i][0] = rand()%255+1;
  10.         trainingData[i][1] = rand()%255+1;
  11.         trainingData[i+5][0] = rand()%255+255;
  12.         trainingData[i+5][1] = rand()%255+255;
  13.     }
  14.     Mat trainingDataMat(10, 2, CV_32FC1, trainingData);
  15.     cout<<trainingDataMat<<endl;
  16.     CvKNearest knn;
  17.     knn.train(trainingDataMat,labelsMat,Mat(), false, 2 );
  18.     // Data for visual representation
  19.     int width = 512, height = 512;
  20.     Mat image = Mat::zeros(height, width, CV_8UC3);
  21.     Vec3b green(0,255,0), blue (255,0,0);
  22.     for (int i = 0; i < image.rows; ++i){
  23.         for (int j = 0; j < image.cols; ++j){
  24.             const Mat sampleMat = (Mat_<float>(1,2) << i,j);
  25.             Mat response;
  26.             float result = knn.find_nearest(sampleMat,1);
  27.             if (result !=0){
  28.                 image.at<Vec3b>(j, i)  = green;
  29.             }
  30.             else
  31.                 image.at<Vec3b>(j, i)  = blue;
  32.         }
  33.     }
  34.         // Show the training data
  35.         for(int i=0;i<5;i++){
  36.             circle( image, Point(trainingData[i][0],  trainingData[i][1]),
  37.                 5, Scalar(  0,   0,   0), -1, 8);
  38.             circle( image, Point(trainingData[i+5][0],  trainingData[i+5][1]),
  39.                 5, Scalar(255, 255, 255), -1, 8);
  40.         }
  41.         imshow(“KNN Simple Example”, image); // show it to the user
  42.         waitKey(10000);
  43. }

分类结果如下:

预测函数find_nearest()除了输入sample参数外还有些其他的参数:
  1. float CvKNearest::find_nearest(const Mat& samples, int k, Mat* results=0,
  2. const float** neighbors=0, Mat* neighborResponses=0, Mat* dist=0 )

即,samples 为样本数*特征数的浮点矩阵;K为寻找最近点的个数;results与预测结果;neibhbors为k*样本数的指针数组(输入为const,实在不知 为何如此设计);neighborResponse为样本数*k的每个样本K个近邻的输出值;dist为样本数*k的每个样本K个近邻的距离。

另一个例子

OpenCV refman也提供了一个类似的示例,使用CvMat格式的输入参数:
  1. int main( int argc, char** argv )
  2. {
  3.     const int K = 10;
  4.     int i, j, k, accuracy;
  5.     float response;
  6.     int train_sample_count = 100;
  7.     CvRNG rng_state = cvRNG(-1);
  8.     CvMat* trainData = cvCreateMat( train_sample_count, 2, CV_32FC1 );
  9.     CvMat* trainClasses = cvCreateMat( train_sample_count, 1, CV_32FC1 );
  10.     IplImage* img = cvCreateImage( cvSize( 500, 500 ), 8, 3 );
  11.     float _sample[2];
  12.     CvMat sample = cvMat( 1, 2, CV_32FC1, _sample );
  13.     cvZero( img );
  14.     CvMat trainData1, trainData2, trainClasses1, trainClasses2;
  15.     // form the training samples
  16.     cvGetRows( trainData, &trainData1, 0, train_sample_count/2 );
  17.     cvRandArr( &rng_state, &trainData1, CV_RAND_NORMAL, cvScalar(200,200), cvScalar(50,50) );
  18.     cvGetRows( trainData, &trainData2, train_sample_count/2, train_sample_count );
  19.     cvRandArr( &rng_state, &trainData2, CV_RAND_NORMAL, cvScalar(300,300), cvScalar(50,50) );
  20.     cvGetRows( trainClasses, &trainClasses1, 0, train_sample_count/2 );
  21.     cvSet( &trainClasses1, cvScalar(1) );
  22.     cvGetRows( trainClasses, &trainClasses2, train_sample_count/2, train_sample_count );
  23.     cvSet( &trainClasses2, cvScalar(2) );
  24.     // learn classifier
  25.     CvKNearest knn( trainData, trainClasses, 0, false, K );
  26.     CvMat* nearests = cvCreateMat( 1, K, CV_32FC1);
  27.     for( i = 0; i < img->height; i++ )
  28.     {
  29.         for( j = 0; j < img->width; j++ )
  30.         {
  31.             sample.data.fl[0] = (float)j;
  32.             sample.data.fl[1] = (float)i;
  33.             // estimate the response and get the neighbors’ labels
  34.             response = knn.find_nearest(&sample,K,0,0,nearests,0);
  35.             // compute the number of neighbors representing the majority
  36.             for( k = 0, accuracy = 0; k < K; k++ )
  37.             {
  38.                 if( nearests->data.fl[k] == response)
  39.                     accuracy++;
  40.             }
  41.             // highlight the pixel depending on the accuracy (or confidence)
  42.             cvSet2D( img, i, j, response == 1 ?
  43.                 (accuracy > 5 ? CV_RGB(180,0,0) : CV_RGB(180,120,0)) :
  44.                 (accuracy > 5 ? CV_RGB(0,180,0) : CV_RGB(120,120,0)) );
  45.         }
  46.     }
  47.     // display the original training samples
  48.     for( i = 0; i < train_sample_count/2; i++ )
  49.     {
  50.         CvPoint pt;
  51.         pt.x = cvRound(trainData1.data.fl[i*2]);
  52.         pt.y = cvRound(trainData1.data.fl[i*2+1]);
  53.         cvCircle( img, pt, 2, CV_RGB(255,0,0), CV_FILLED );
  54.         pt.x = cvRound(trainData2.data.fl[i*2]);
  55.         pt.y = cvRound(trainData2.data.fl[i*2+1]);
  56.         cvCircle( img, pt, 2, CV_RGB(0,255,0), CV_FILLED );
  57.     }
  58.     cvNamedWindow( “classifier result”, 1 );
  59.     cvShowImage( “classifier result”, img );
  60.     cvWaitKey(0);
  61.     cvReleaseMat( &trainClasses );
  62.     cvReleaseMat( &trainData );
  63.     return 0;
  64. }

分类结果:

KNN的思想很好理解,也非常容易实现,同时分类结果较高,对异常值不敏感。但计算复杂度较高,不适于大数据的分类问题。

转载注明来源:CV视觉网 » K-近邻分类算法KNN

分享到:更多 ()
扫描二维码,给作者 打赏
pay_weixinpay_weixin

请选择你看完该文章的感受:

0不错 0超赞 0无聊 0扯淡 0不解 0路过

评论 6

评论前必须登录!