计算机视觉
图像处理

机器学习(四)高斯混合模型

高斯混合算法是EM算法的一个典型的应用,EM算法的推导过程这里不打算详解,直接讲GMM算法的实现。之前做图像分割grab cut 算法的时候,只知道把opencv中的高斯混合模型代码复制下来,然后封装成类使用,学的比较浅。结果没过几天发现高斯混合算法又忘了差不多了,于是用matlab去亲自写过一遍,终于发现了高斯混合模型的奥义。我的理解是高斯混合模型其实是进化版的k均值算法,因此学习高斯混合模型,最好还是把k均值算法写过一遍。高斯混合与k均值的本质区别在于权值问题,k均值采用的是均匀权值,而高斯混合的权值需要根据高斯模型的概率进行确定。

开始学习高斯混合模型,需要先简单复习一下单高斯模型的参数估计方法,描述一个高斯模型其实就是要计算它的均值、协方差矩阵(一维空间为方差,二维以上称之为协方差矩阵):

假设有数据集X={x1,x2,x3……,xn},那么用这些数据来估计单高斯模型参数的计算公式为

OK,开始写代码前,先用matlab生成数据集,然后在进行聚类:

利用matlab的生成高斯模型数据集X:

  1. mu = [2 3];
  2. SIGMA = [1 0; 0 2];
  3. r1 = mvnrnd(mu,SIGMA,1000);
  4. plot(r1(:,1),r1(:,2),‘r+’);

 

然后利用上面的估计方法计算均值,和协方差是否满足均值为[2 3],协方差为[1 0; 0 2];测试代码如下,r2、covmat即为计算结果

  1. [m n]=size(r1);
  2. center=sum(r1)./m;
  3. r2(:,1)=r1(:,1)-center(1);
  4. r2(:,2)=r1(:,2)-center(2);
  5. covmat=1/m*r2’*r2;

 

先把单高斯模型的函数写好,因为高斯混合模型是它的进化版,计算高斯混合模型过程中需要调用单高斯模型参数估计,写好代码后面才不会乱掉。开始高斯混合建模之前,我先用matlab生成一个测试数据集data,如下图,然后再进行算法测试。

生成数据集代码如下:

  1. %生成测试数据
  2. mu = [2 3];%测试数据1
  3. SIGMA = [1 0; 0 2];
  4. r1 = mvnrnd(mu,SIGMA,100);
  5. plot(r1(:,1),r1(:,2),‘.’);
  6. hold on;
  7. mu = [10 10];%测试数据2
  8. SIGMA = [ 1 0; 0 2];
  9. r2 = mvnrnd(mu,SIGMA,100);
  10. plot(r2(:,1),r2(:,2),‘.’);
  11. mu = [5 8];%测试数据3
  12. SIGMA = [ 1 0; 0 2];
  13. r3= mvnrnd(mu,SIGMA,100);
  14. plot(r3(:,1),r3(:,2),‘.’);
  15. data=[r1;r2;r3];

 

ok,数据生成完毕,接着我们正式开始高斯混合算法解析,先看一下高斯混合模型的建模求参步骤:

高斯混合模型的求解,说得简单一点就是要求解高斯模型中的均值与协方差,现在我们要把上述的数据分成3类,那么我们就是要求解3个均值及其对应的3个协方差矩阵。先讲一下总体步骤,高斯混合模型包含3个步骤:

a.初始化各个高斯模型的参数,及每个高斯模型的权重;

b.根据各个高斯模型的参数及其权重,计算每个点属于各个高斯模型的权重,计算公式为:

其中:,Wj是每个高斯模型在这个模型所占用得权重。这个公式说的简单一点就是每个高斯模型的权重与其概率的乘积,这样计算出来就相当于每个高斯模型在每个数据点中的所占用的比例。

c.更新各个高斯模型的均值与方差,计算公式如下:

d.更新各个高斯模型的总权重,计算公式如下:

其实第c、d两个步骤,无所谓顺序,你完全可以总权重更新放在各个模型参数更新之前。迭代过程就是b、c、d三个步骤进行更新就可以了。OK,接着结合上面的公式写一写代码。

(1)初始化高斯模型参数。

这一步初始化,在实际应用中一般是先通过k均值算法进行初始聚类,然后根据聚类结果进行计算初始化参数。不过这里我为了测试,我们选择随机初始化,这样才能看出GMM算法到底能不能实现聚类。

我这里各个高斯模型初始均值(中心)的初始化方法选择跟k均值的初始化方法一样,也就是随机选择k个点位置作为k个高斯模型的初始均值。然后协方差矩阵的初始化,我选择单位矩阵,具体代码如下:

  1. [m n]=size(data);
  2. kn=3;
  3. countflag=zeros(1,kn);
  4. tdata=cell(1,kn);%建立3个空矩阵
  5. mu=cell(1,kn);%建立3个空矩阵
  6. sigma=cell(1,kn);%建立3个空矩阵
  7. %方案2 随机初始化参数
  8. for i=1:kn
  9.     mu{1,i}=data(i*10,:);
  10.     sigma{1,i}=eye(2,2);
  11.     weightp(i)=1/kn;
  12. end

 

(2)计算各个模型在各个点的权重值

这一步是计算每个数据点属于各个高斯混合的概率,说白了就是计算权值:

  1. pro_ij=zeros(m,kn);%存储每个点属于每个类的概率
  2. for i=1:m
  3.     sumpk=0;
  4.     for j=1:kn
  5.         pk(j)=weightp(j)*GSMPro(mu{1,j},sigma{1,j},data(i,:));
  6.         sumpk=sumpk+pk(j);
  7.     end
  8.     for j=1:kn
  9.         pro_ij(i,j)=pk(j)/sumpk;
  10.     end
  11. end

 

(3)步骤c 更新参数

  1. for j=1:kn
  2.      [mu{1,j},sigma{1,j}]=WeightGSM(data,pro_ij(:,j));
  3.  end

 

(4)步骤d 更新各个模型的总权重

  1. for j=1:kn
  2.       weightp(j)=sum(pro_ij(:,j))/m;
  3.   end

 

然后把步骤2、3、4的代码放在循环语句中进行迭代就ok了。最后贴一下整份代码:

1、脚本文件:

  1. close all;
  2. clear;
  3. clc;
  4. %生成测试数据
  5. mu = [2 3];%测试数据1
  6. SIGMA = [1 0; 0 2];
  7. r1 = mvnrnd(mu,SIGMA,100);
  8. plot(r1(:,1),r1(:,2),‘.’);
  9. hold on;
  10. mu = [10 10];%测试数据2
  11. SIGMA = [ 1 0; 0 2];
  12. r2 = mvnrnd(mu,SIGMA,100);
  13. plot(r2(:,1),r2(:,2),‘.’);
  14. mu = [5 8];%测试数据3
  15. SIGMA = [ 1 0; 0 2];
  16. r3= mvnrnd(mu,SIGMA,100);
  17. plot(r3(:,1),r3(:,2),‘.’);
  18. data=[r1;r2;r3];
  19. [m n]=size(data);
  20. kn=3;
  21. countflag=zeros(1,kn);
  22. tdata=cell(1,kn);%建立10个空矩阵
  23. mu=cell(1,kn);%建立10个空矩阵
  24. sigma=cell(1,kn);%建立10个空矩阵
  25. % 方案1 初始化采用kmeans,做参数的初步估计
  26. % Idx=kmeans(data,kn);
  27. % figure(2);%绘制初始化结果
  28. % hold on;
  29. for i=1:m
  30. %     if Idx(i)==1
  31. %         plot(data(i,1),data(i,2),‘.y’);
  32. %     elseif Idx(i)==2
  33. %          plot(data(i,1),data(i,2),‘.b’);
  34. %     end
  35. % end
  36. for i=1:m
  37. %    tdata{1,Idx(i)}=[tdata{1,Idx(i)};data(i,:)];
  38. % end
  39. for i=1:kn
  40. %     [mu{1,i},sigma{1,i}]=GSMData(tdata{1,i});
  41. % end
  42. for i=1:kn
  43. %     [trow,tcol]=size(tdata{1,i});
  44. %     weightp(i)=trow/m;
  45. % end
  46. %方案2 随机初始化
  47. for i=1:kn
  48.     mu{1,i}=data(i*10,:);
  49.     sigma{1,i}=eye(2,2);
  50.     weightp(i)=1/kn;
  51. end
  52. it=1;
  53. while it<1000
  54.     %E步 计算每个点处于每个类的概率
  55.     pro_ij=zeros(m,kn);%存储每个点属于每个类的概率
  56.     for i=1:m
  57.         sumpk=0;
  58.         for j=1:kn
  59.             pk(j)=weightp(j)*GSMPro(mu{1,j},sigma{1,j},data(i,:));
  60.             sumpk=sumpk+pk(j);
  61.         end
  62.         for j=1:kn
  63.             pro_ij(i,j)=pk(j)/sumpk;
  64.         end
  65.     end
  66.     %M步
  67.     for j=1:kn
  68.         [mu{1,j},sigma{1,j}]=WeightGSM(data,pro_ij(:,j));
  69.     end
  70.     %更新权值
  71.     for j=1:kn
  72.         weightp(j)=sum(pro_ij(:,j))/m;
  73.     end
  74.     sumw=sum(weightp);
  75.     it=it+1;
  76. end
  77. for i=1:m
  78.     [value index]=max(pro_ij(i,:));
  79.     Idx(i)=index;
  80. end
  81. figure(2);
  82. hold on;
  83. for i=1:m
  84.     if Idx(i)==1
  85.         plot(data(i,1),data(i,2),‘.y’);
  86.     elseif Idx(i)==2
  87.          plot(data(i,1),data(i,2),‘.b’);
  88.     elseif Idx(i)==3
  89.          plot(data(i,1),data(i,2),‘.r’);
  90.     end
  91. end
  92. % figure(3);
  93. % %px=gmmstd(data,3);
  94. for i=1:m
  95. %     [value index]=max(px(i,:));
  96. %     Idx(i)=index;
  97. % end
  98. % hold on;
  99. for i=1:m
  100. %     if Idx(i)==1
  101. %         plot(data(i,1),data(i,2),‘.y’);
  102. %     elseif Idx(i)==2
  103. %          plot(data(i,1),data(i,2),‘.b’);
  104. %     elseif Idx(i)==3
  105. %          plot(data(i,1),data(i,2),‘.r’);
  106. %     end
  107. % end
  108. %单高斯模型参数估计
  109. % [m n]=size(r1);
  110. % center=sum(r1)./m;
  111. % r2(:,1)=r1(:,1)-center(1);
  112. % r2(:,2)=r1(:,2)-center(2);
  113. % covmat=1/m*r2’*r2;

 

2、相关函数

  1. function [ mu ,sigma ] = WeightGSM(data,weight)
  2.     %计算加权均值
  3.     [m n]=size(data);
  4.     sumweight=sum(weight);
  5.     weightdata=[];
  6.     for i=1:m
  7.         weightdata(i,:)=weight(i)*data(i,:);
  8.     end
  9.     center=sum(weightdata)/sumweight;
  10.     %计算加权协方差
  11.     for i=1:n
  12.        r2(:,i)=data(:,i)-center(i);
  13.     end
  14.     for i=1:m
  15.         r1(i,:)=weight(i)*r2(i,:);
  16.     end
  17.     sigma=1/sumweight*r1’*r2;
  18.     mu=center;
  19. end
  20. function [pro] = GSMPro(mu ,sigma,x)
  21.   pro=exp(-0.5*(x-mu)*inv(sigma)*(x-mu)’);
  22.   pro=1/sqrt(2*pi*det(sigma))*pro;
  23. end

 

看以下最后的测试结果:

转载注明来源:CV视觉网 » 机器学习(四)高斯混合模型

分享到:更多 ()
扫描二维码,给作者 打赏
pay_weixinpay_weixin

请选择你看完该文章的感受:

3不错 2超赞 0无聊 0扯淡 0不解 0路过

评论 5

评论前必须登录!

 

  1. #1

    GSMData这个函数在哪里呀??谢谢啦

    王胜利2年前 (2017-03-01)