计算机视觉
图像处理

EM算法(对EM算法的简单理解,opencv自带EM sample学习)

因做实验的需要,最近在学习EM算法,算法介绍的资料网上是有不少,可是没有一篇深入浅出的介绍,算法公式太多,比较难懂,毕竟她是ML领域10大经典算法之一 ,且一般是结合GMM模型的参数估计来介绍EM的。看过不少EM的资料,现将自己对EM算法用稍微通俗点的文字写下来,当然你可以用GMM这个具体的例子来帮助理解。

  问题的提出:给定一些样本数据x,且知道该数据是有k个高斯混合产生的,现在要用给的样本数据x去估计模型的参数sida,即在该参数sida下产生x数据的概率最大。(其实就是个MLE估计)

  1. 原问题等价与求sida,使得满足max(logP((x/sida))),那么我们为什么不直接用MLE去估计呢?通过关于EM算法各种推导公式(我这里基本把这些公式都省略掉,因为介绍这方面的资料有不少)可以看出,对数里面有求和的项,说白了,就算用MLE的方法去做解不出来,因为各种求偏导什么的很难求。
  2. 所以在EM算法中有个假设,即我们不仅知道观测到的数据x,而且还知道它属于隐变量z的哪一类(GMM中,隐变量z表示各个单高斯模型)。此时原问题的求解等价于求sida,使得满足max(logP((x,z)/sida))
  3. 为什么2中就能用MLE解决呢,又通过查看EM算法各种公式推导可以看出,21的不同在与2中那些对数符号里面没有了求和项,所以各种求导方法等在此可以应用。
  4. 但是我们的z变量是隐含的,也就是说未知的,那么2中的MLE该怎么做呢?通过查找EM算法的公式推导过程可以看出,2中的求max(logP((x,z)/sida))中的sida可以等价与求max[Ez(logP(x,z)/sida)],即求logP((x,z)/sida)关于变量z的期望最大。
  5. 既然是求其关于z的期望,那么我们应该知道z的概率分布才行。比较幸运的是在EM体系中,关于z的分布也是很容易求得的,即z的后验分布P(z/x,sida)很容易求出来。
  6. E-step:首先随便取一组参数sida,求出5z的后验分布,同时求出logP((x,z)/sida)关于z的期望,即Ez(logp((x,z)/sida))
  7. M-step:前面已经讲到,6中的期望最大用MLE很容易解决,所以M-step时采用MLE求得新的参数sida,又从前面的介绍可知,6中的期望最大时的参数等价于原问题的求解的参数。
  8. 返回67之间迭代,直到满足logP((x,z)/sida)基本不再变化。

已经用通俗的语言简单的介绍了下EM算法,在这一节中就采用Opencv自带的一个EM sample来学习下opencvEM 算法类的使用,顺便也体验下EM 算法的实际应用。

环境:Ubuntu12.04+Qt4.8.2+QtCreator2.5+opencv2.4.2

  在这里需要使用2个与EM算法有关的类,即CvEMCvEMParams,这2个类在opencv2.4.2已经放入legacy文件夹中了,说明不久就会被淘汰掉,因为在未来的opencv版本中,将采用Algorithm这个公共类来统一接口。不过CvEMCvEMParams的使用与其类似,且可以熟悉EM算法的使用流程。

  需要注意的是这2个类虽然是与EM算法有关,可是只能解决GMM问题,比较局限。也许这是将其放在legacy中的原因吧。

 

  实验流程:

  首先产生需要聚类的样本数据,我这里采用的是9个混合的二维高斯分布,所以需要被聚类成9类,这些GMM排成3*3的格式,每一格25个点,共225个训练样本。在软件中显示出样本点的分布。

  用类EMCvEMParams初始化emem_params对象。

  设置EM参数类em_params的各个参数,这里的均值、权值、方差的初始化采用的是kmeans聚类得到的,em_params的参数中需要特别指定的是所聚类类别N(这里等于9.

  用这255个数据进行训练EM模型,采用的是CvEM类方法train()函数。

  把窗口大小500*500内的每个点用训练出来的EM模型进行预测,将预测结果用不同的颜色在软件中画出来。

  把训练过程中样本的类别标签(程序中保存在label中)在图像中显示出来。

 

  实验结果:

  软件界面图:

 

  按下Gnenrate Data按钮后显示如下:

 

  按下EM Cluster按钮后显示如下:

 

实验代码:

mainwindow.h:

#ifndef MAINWINDOW_H
#define MAINWINDOW_H

#include <QMainWindow>
//#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/ml/ml.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/legacy/legacy.hpp>

using namespace cv;
using namespace std;
//using std::vector;

namespace Ui {
class MainWindow;
}

class MainWindow : public QMainWindow
{
    Q_OBJECT
    
public:
    explicit MainWindow(QWidget *parent = 0);
    ~MainWindow();

    vector<Scalar> colors;
    
private slots:

    void on_closeButton_clicked();

    void on_generateButton_clicked();

    void on_clusterButton_clicked();

private:
    Ui::MainWindow *ui;

    int nsamples;
    int N, N1;
    Mat img, img1;
    Mat samples, sample_predict;
    Mat labels;
    CvEM em;
    CvEMParams em_params;
};

#endif // MAINWINDOW_H

 

mainwindow.cpp:

#include "mainwindow.h"
#include "ui_mainwindow.h"
#include <QImage>

MainWindow::MainWindow(QWidget *parent) :
    QMainWindow(parent),
    ui(new Ui::MainWindow)
{
    ui->setupUi(this);
    N = 9;
    N1 = (int)sqrt(double(N));
    nsamples = 225;
    img = Mat( Size(500, 500), CV_8UC3 );

    colors.resize(N);
    colors.at(0) = Scalar(0, 255, 255);
    colors.at(1) = Scalar(255, 0, 255);
    colors.at(2) = Scalar(255, 255, 0);
    colors.at(3) = Scalar(255, 0, 0);
    colors.at(4) = Scalar(0, 255, 0);
    colors.at(5) = Scalar(0, 0, 255);
    colors.at(6) = Scalar(255, 100, 100);
    colors.at(7) = Scalar(100, 255, 100);
    colors.at(8) = Scalar(100, 100, 255);

}

MainWindow::~MainWindow()
{
    delete ui;
}


void MainWindow::on_closeButton_clicked()
{
    close();
}

void MainWindow::on_generateButton_clicked()
{
    samples = Mat( nsamples, 2, CV_32FC1);//用来存储产生的二维随机点
    samples = samples.reshape( 2, 0 );//转换成2通道的矩阵,reshape函数只适应而2维图像

    //初始化样本
    for( int i = 0; i < N; i++ )
        {
            Mat sub_samples = samples.rowRange( i*nsamples/N, (i+1)*nsamples/N );
            Scalar mean( (i%N1+1)*img.rows/(N1+1), (i/N1+1)*img.rows/(N1+1));
            Scalar var( 30, 30 );
            randn( sub_samples, mean, var );
        }
    samples = samples.reshape( 1, 0 );

    //显示样本数据
    for( int j = 0; j < nsamples; j++ )
    {
        Point gene_sample;
        gene_sample.x = cvRound(samples.at<float>(j, 0));
        gene_sample.y = cvRound(samples.at<float>(j, 1));
        circle( img, gene_sample, 1, Scalar(0, 255, 250), 1, 8 );
    }
    cvtColor( img, img, CV_BGR2RGB );

    /*Qt中处理图像有4个类,分别为QImage,QPixmap,QBitmap,QPicture.其中QPixmap专门负责在屏幕上显示图片
    的,QImage专门负责和I/O方面的,QBitmap是从QPixmap中继承来的,只负责一个通道的图像处理,QPicture是
    专门用来负责画图的*/
    QImage qimg = QImage( img.data, img.cols, img.rows, QImage::Format_RGB888 );
    //setPixmap为QLabel发出的公共信号,fromImage函数为将图片转换程QPixmap的格式
    ui->imgLabel->setPixmap( QPixmap::fromImage( qimg ) );
}

void MainWindow::on_clusterButton_clicked()
{
    //给EM算法参赛赋值,均值,方差和权值采用kmeans初步聚类得到
    em_params.means = NULL;
    em_params.covs = NULL;
    em_params.weights = NULL;
    em_params.nclusters = N;
    em_params.start_step = CvEM::START_AUTO_STEP;
    em_params.cov_mat_type = CvEM::COV_MAT_SPHERICAL;
    //达到最大迭代次数或者迭代误差小到一定值,应该有系统默认的值
    em_params.term_crit.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS;

    cvtColor( img, img, CV_RGB2BGR );

    //EM算法训练过程
    em.train( samples, Mat(), em_params, &labels );

    //画出背景图
    sample_predict = Mat( 1, 2, CV_32FC1 );
    for( int i = 0; i < img.rows; i++ )
        for( int j = 0; j < img.cols; j++ )
            {
                sample_predict.at<float>(0) = (float)i;
                sample_predict.at<float>(1) = (float)j;
                int value = cvRound(em.predict( sample_predict ));//返回的value为预测类标签
                circle( img, Point(i, j), 1, 0.1*colors.at(value), 1, 8 );
            }

    //画出样本点的聚类情况
    for( int n = 0; n < nsamples; n++ )
        circle( img, Point(cvRound(samples.at<float>(n, 0)), cvRound(samples.at<float>(n, 1))),
                1, colors.at( labels.at<int>(n)), 1, 8 );//因为此时labels保存的是类标签(1~N),为整型

    //显示图像
    cvtColor( img, img, CV_BGR2RGB );
    QImage qimg = QImage( img.data, img.cols, img.rows, QImage::Format_RGB888 );
    ui->imgLabel->setPixmap( QPixmap::fromImage(qimg) );


}

main.cpp:

#include <QApplication>
#include "mainwindow.h"

int main(int argc, char *argv[])
{
    QApplication a(argc, argv);
    MainWindow w;
    w.show();
    
    return a.exec();
}

实验总结:

要学会数据点产生的类似方法,特别是reshape函数的使用方法。

要学会用STL的vector,这个容器要比数组方便很多。

要多学点C++的编程思想。

附录:工程code下载地址

转载注明来源:CV视觉网 » EM算法(对EM算法的简单理解,opencv自带EM sample学习)

分享到:更多 ()
扫描二维码,给作者 打赏
pay_weixinpay_weixin

请选择你看完该文章的感受:

0不错 0超赞 0无聊 0扯淡 0不解 0路过

评论 3

评论前必须登录!