|
| 1 | +# EM算法(期望最大算法) |
| 2 | + |
| 3 | +## 概述 |
| 4 | + |
| 5 | +EM算法是常用的估计参数隐变量的利器,它是一种迭代式的方法,其基本想法是:若参数已知,则可根据训练数据推断出最优隐变量的值;若最优隐变量的值已知,则可对参数做最大似然估计 |
| 6 | + |
| 7 | +## Jensen不等式 |
| 8 | + |
| 9 | +若f是定义域为实数的函数,如果对于所有实数x,f''(x)>=0,那么f是凸函数。当x是向量时,如果其hession矩阵是H是半正定的(H>=0),那么f是凸函数。如果f''(x)>0或者H>0,那么f是严格的凸函数。 |
| 10 | + |
| 11 | +Jensen不等式: |
| 12 | + |
| 13 | +如果f是凸函数,X是随机变量,那么E[f(X)]>=f(EX) |
| 14 | + |
| 15 | +如果f是严格凸函数,那么E[f(X)]=f(EX)当且仅当p(x=E[X])=1 |
| 16 | + |
| 17 | +如图: |
| 18 | + |
| 19 | + |
| 20 | + |
| 21 | +## EM算法 |
| 22 | + |
| 23 | + |
| 24 | + |
| 25 | + |
| 26 | + |
| 27 | + |
| 28 | + |
| 29 | +## EM算法在高斯混合模型中的应用 |
| 30 | + |
| 31 | + |
| 32 | + |
| 33 | + |
| 34 | + |
| 35 | + |
| 36 | + |
| 37 | + |
| 38 | + |
| 39 | + |
| 40 | + |
| 41 | + |
| 42 | + |
| 43 | + |
| 44 | + |
| 45 | +## 高斯混合模型算法实现 |
| 46 | + |
| 47 | +```python |
| 48 | +from numpy import * |
| 49 | + |
| 50 | +def initData(k,u,sigma,dataNum): |
| 51 | + ''' |
| 52 | + 初始化高斯混合模型的数据 |
| 53 | + k:比例系数 |
| 54 | + u:均值 |
| 55 | + sigma:标准差 |
| 56 | + dataNum:数据个数 |
| 57 | + ''' |
| 58 | + dataSet=zeros(dataNum,dtype=float) |
| 59 | + #高斯分布个数 |
| 60 | + n=len(k) |
| 61 | + for i in range(dataNum): |
| 62 | + #产生0-1的随机数 |
| 63 | + rand=random.random() |
| 64 | + sK=0 |
| 65 | + index=0 |
| 66 | + while index<n: |
| 67 | + sK+=k[index] |
| 68 | + if rand<sK: |
| 69 | + dataSet[i]=random.normal(u[index],sigma[index]) |
| 70 | + break |
| 71 | + else: |
| 72 | + index+=1 |
| 73 | + return dataSet |
| 74 | + |
| 75 | + |
| 76 | +def normalFun(x,u,sigma): |
| 77 | + ''' |
| 78 | + 计算均值为u,标准差为sigma的正太分布函数的密度函数值 |
| 79 | + ''' |
| 80 | + return (1.0/sqrt(2*pi)*sigma)*(exp(-(x-u)**2/(2*sigma**2))) |
| 81 | + |
| 82 | +def em(dataSet,k,u,sigma,step=10): |
| 83 | + ''' |
| 84 | + 高斯混合模型 |
| 85 | + ''' |
| 86 | + n=len(k) |
| 87 | + dataNum=len(dataArr) |
| 88 | + gamaArr=zeros((n,dataNum)) |
| 89 | + for s in range(step): |
| 90 | + #E步,确定Q函数 |
| 91 | + for i in range(n): |
| 92 | + for j in range(dataNum): |
| 93 | + wSum=sum([k[t]*normalFun(dataSet[j],u[t],sigma[t]) for t in range(n)]) |
| 94 | + gamaArr[i][j]=k[i]*normalFun(dataSet[j],u[i],sigma[i])/float(wSum) |
| 95 | + |
| 96 | + #M步 |
| 97 | + #更新u |
| 98 | + for i in range(n): |
| 99 | + u[i]=sum(gamaArr[i]*dataSet)/sum(gamaArr[i]) |
| 100 | + #更新sigma |
| 101 | + for i in range(n): |
| 102 | + sigma[i]=sqrt(sum(gamaArr[i]*(dataSet-u[i])**2)/sum(gamaArr[i])) |
| 103 | + #更新k |
| 104 | + for i in range(n): |
| 105 | + k[i]=sum(gamaArr[i])/dataNum |
| 106 | + |
| 107 | + return [k,u,sigma] |
| 108 | +``` |
| 109 | + |
0 commit comments