Skip to content

Commit d655be0

Browse files
committed
update em
1 parent 0fd9cb1 commit d655be0

File tree

3 files changed

+305
-0
lines changed

3 files changed

+305
-0
lines changed

15_EM/EM.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from numpy import *
2+
3+
def initData(k, u, sigma, dataNum):
    '''
    Generate sample data from a 1-D Gaussian mixture model.

    k: mixing coefficients (expected to sum to 1)
    u: component means
    sigma: component standard deviations
    dataNum: number of samples to draw
    Returns a 1-D float array of dataNum samples.
    '''
    samples = zeros(dataNum, dtype=float)
    # number of Gaussian components
    numComponents = len(k)
    for idx in range(dataNum):
        # a uniform draw in [0, 1) selects which component generates this sample
        r = random.random()
        cumulative = 0
        for comp in range(numComponents):
            cumulative += k[comp]
            if r < cumulative:
                samples[idx] = random.normal(u[comp], sigma[comp])
                break
    return samples
27+
28+
29+
def normalFun(x, u, sigma):
    '''
    Density of the normal distribution N(u, sigma^2) evaluated at x.

    x: evaluation point
    u: mean
    sigma: standard deviation (must be > 0)
    '''
    # Bug fix: the original wrote (1.0/sqrt(2*pi)*sigma), which MULTIPLIES
    # by sigma; the correct normal pdf normalization divides by sigma:
    #   pdf(x) = 1/(sigma*sqrt(2*pi)) * exp(-(x-u)^2 / (2*sigma^2))
    return (1.0 / (sigma * sqrt(2 * pi))) * exp(-(x - u) ** 2 / (2 * sigma ** 2))
34+
35+
def em(dataSet, k, u, sigma, step=10):
    '''
    Fit a 1-D Gaussian mixture model to dataSet with the EM algorithm.

    dataSet: 1-D array of observations
    k: initial mixing coefficients (updated in place)
    u: initial means (updated in place)
    sigma: initial standard deviations (updated in place)
    step: number of EM iterations to run
    Returns [k, u, sigma] holding the fitted parameters.
    '''
    n = len(k)
    # Bug fix: was len(dataArr), which referenced a global defined in the
    # __main__ block instead of the dataSet parameter (NameError on import).
    dataNum = len(dataSet)
    gamaArr = zeros((n, dataNum))
    for s in range(step):
        # E-step: responsibilities gamaArr[i][j] = P(component i | sample j)
        for j in range(dataNum):
            # weighted densities for sample j under every component; this
            # does not depend on i, so it is computed once per sample
            # (the original recomputed it inside the i-loop)
            w = [k[t] * normalFun(dataSet[j], u[t], sigma[t]) for t in range(n)]
            wSum = float(sum(w))
            for i in range(n):
                gamaArr[i][j] = w[i] / wSum

        # M-step: re-estimate parameters from the responsibilities
        # update means
        for i in range(n):
            u[i] = sum(gamaArr[i] * dataSet) / sum(gamaArr[i])
        # update standard deviations (uses the freshly updated means)
        for i in range(n):
            sigma[i] = sqrt(sum(gamaArr[i] * (dataSet - u[i]) ** 2) / sum(gamaArr[i]))
        # update mixing coefficients
        for i in range(n):
            k[i] = sum(gamaArr[i]) / dataNum

    return [k, u, sigma]
61+
62+
if __name__ == '__main__':
    # ground-truth parameters of the generating mixture
    k = [0.3, 0.4, 0.3]
    u = [2, 4, 3]
    sigma = [1, 1, 4]
    # number of samples to draw
    dataNum = 5000
    dataArr = initData(k, u, sigma, dataNum)

    # initial guesses for the EM iteration
    k0 = [0.3, 0.3, 0.4]
    u0 = [1, 2, 2]
    sigma0 = [1, 1, 1]
    step = 100

    k1, u1, sigma1 = em(dataArr, k0, u0, sigma0, step)
    print("参数实际值:")
    print("k:", k)
    print("u:", u)
    print("sigma:", sigma)

    print("参数估计值:")
    print("k1:", k1)
    print("u1:", u1)
    print("sigma1:", sigma1)
87+

15_EM/README.md

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# EM算法(期望最大算法)
2+
3+
## 概述
4+
5+
EM算法是常用的估计参数隐变量的利器,它是一种迭代式的方法,其基本想法是:若参数已知,则可根据训练数据推断出最优隐变量的值;若最优隐变量的值已知,则可对参数做最大似然估计
6+
7+
## Jensen不等式
8+
9+
若f是定义域为实数的函数,如果对于所有实数x,f''(x)>=0,那么f是凸函数。当x是向量时,如果其Hessian矩阵H是半正定的(H>=0),那么f是凸函数。如果f''(x)>0或者H>0,那么f是严格的凸函数。
10+
11+
Jensen不等式:
12+
13+
如果f是凸函数,X是随机变量,那么E[f(X)]>=f(EX)
14+
15+
如果f是严格凸函数,那么E[f(X)]=f(EX)当且仅当p(x=E[X])=1
16+
17+
如图:
18+
19+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/Jensen%E4%B8%8D%E7%AD%89%E5%BC%8F.PNG)
20+
21+
## EM算法
22+
23+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/EM%E7%AE%97%E6%B3%95E%E6%AD%A5.PNG)
24+
25+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/EM%E7%AE%97%E6%B3%95M%E6%AD%A51.PNG)
26+
27+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/EM%E7%AE%97%E6%B3%95M%E6%AD%A52.PNG)
28+
29+
## EM算法在高斯混合模型中的应用
30+
31+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B1.PNG)
32+
33+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B2.PNG)
34+
35+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B3.PNG)
36+
37+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B4.PNG)
38+
39+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B5.PNG)
40+
41+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B6.PNG)
42+
43+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B7.PNG)
44+
45+
## 高斯混合模型算法实现
46+
47+
```python
48+
from numpy import *

def initData(k, u, sigma, dataNum):
    '''
    Generate dataNum samples from a 1-D Gaussian mixture model.

    k: mixing coefficients (expected to sum to 1)
    u: component means
    sigma: component standard deviations
    dataNum: number of samples to draw
    '''
    dataSet = zeros(dataNum, dtype=float)
    # number of Gaussian components
    n = len(k)
    for i in range(dataNum):
        # uniform draw in [0, 1) picks the generating component
        rand = random.random()
        sK = 0
        index = 0
        while index < n:
            sK += k[index]
            if rand < sK:
                dataSet[i] = random.normal(u[index], sigma[index])
                break
            else:
                index += 1
    return dataSet


def normalFun(x, u, sigma):
    '''
    Density of the normal distribution N(u, sigma^2) evaluated at x.
    '''
    # Bug fix: normalization is 1/(sigma*sqrt(2*pi)), not (1/sqrt(2*pi))*sigma.
    return (1.0 / (sigma * sqrt(2 * pi))) * exp(-(x - u) ** 2 / (2 * sigma ** 2))


def em(dataSet, k, u, sigma, step=10):
    '''
    Fit a 1-D Gaussian mixture model to dataSet with the EM algorithm.

    dataSet: 1-D array of observations
    k, u, sigma: initial mixing coefficients / means / standard deviations
                 (updated in place)
    step: number of EM iterations
    Returns [k, u, sigma] holding the fitted parameters.
    '''
    n = len(k)
    # Bug fix: was len(dataArr), a global leaking in from the demo script.
    dataNum = len(dataSet)
    gamaArr = zeros((n, dataNum))
    for s in range(step):
        # E-step: responsibilities gamaArr[i][j] = P(component i | sample j)
        for j in range(dataNum):
            # weights for sample j (independent of i, so computed once)
            w = [k[t] * normalFun(dataSet[j], u[t], sigma[t]) for t in range(n)]
            wSum = float(sum(w))
            for i in range(n):
                gamaArr[i][j] = w[i] / wSum

        # M-step: re-estimate means, then standard deviations, then weights
        for i in range(n):
            u[i] = sum(gamaArr[i] * dataSet) / sum(gamaArr[i])
        for i in range(n):
            sigma[i] = sqrt(sum(gamaArr[i] * (dataSet - u[i]) ** 2) / sum(gamaArr[i]))
        for i in range(n):
            k[i] = sum(gamaArr[i]) / dataNum

    return [k, u, sigma]
```
109+
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# EM算法(期望最大算法)
2+
3+
## 概述
4+
5+
EM算法是常用的估计参数隐变量的利器,它是一种迭代式的方法,其基本想法是:若参数已知,则可根据训练数据推断出最优隐变量的值;若最优隐变量的值已知,则可对参数做最大似然估计
6+
7+
## Jensen不等式
8+
9+
若f是定义域为实数的函数,如果对于所有实数x,f''(x)>=0,那么f是凸函数。当x是向量时,如果其Hessian矩阵H是半正定的(H>=0),那么f是凸函数。如果f''(x)>0或者H>0,那么f是严格的凸函数。
10+
11+
Jensen不等式:
12+
13+
如果f是凸函数,X是随机变量,那么E[f(X)]>=f(EX)
14+
15+
如果f是严格凸函数,那么E[f(X)]=f(EX)当且仅当p(x=E[X])=1
16+
17+
如图:
18+
19+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/Jensen%E4%B8%8D%E7%AD%89%E5%BC%8F.PNG)
20+
21+
## EM算法
22+
23+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/EM%E7%AE%97%E6%B3%95E%E6%AD%A5.PNG)
24+
25+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/EM%E7%AE%97%E6%B3%95M%E6%AD%A51.PNG)
26+
27+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/EM%E7%AE%97%E6%B3%95M%E6%AD%A52.PNG)
28+
29+
## EM算法在高斯混合模型中的应用
30+
31+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B1.PNG)
32+
33+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B2.PNG)
34+
35+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B3.PNG)
36+
37+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B4.PNG)
38+
39+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B5.PNG)
40+
41+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B6.PNG)
42+
43+
![](https://github.com/TonyJent/myMachineLearning/blob/master/images/15_EM/%E9%AB%98%E6%96%AF%E6%B7%B7%E5%90%88%E6%A8%A1%E5%9E%8B7.PNG)
44+
45+
## 高斯混合模型算法实现
46+
47+
```python
48+
from numpy import *

def initData(k, u, sigma, dataNum):
    '''
    Generate dataNum samples from a 1-D Gaussian mixture model.

    k: mixing coefficients (expected to sum to 1)
    u: component means
    sigma: component standard deviations
    dataNum: number of samples to draw
    '''
    dataSet = zeros(dataNum, dtype=float)
    # number of Gaussian components
    n = len(k)
    for i in range(dataNum):
        # uniform draw in [0, 1) picks the generating component
        rand = random.random()
        sK = 0
        index = 0
        while index < n:
            sK += k[index]
            if rand < sK:
                dataSet[i] = random.normal(u[index], sigma[index])
                break
            else:
                index += 1
    return dataSet


def normalFun(x, u, sigma):
    '''
    Density of the normal distribution N(u, sigma^2) evaluated at x.
    '''
    # Bug fix: normalization is 1/(sigma*sqrt(2*pi)), not (1/sqrt(2*pi))*sigma.
    return (1.0 / (sigma * sqrt(2 * pi))) * exp(-(x - u) ** 2 / (2 * sigma ** 2))


def em(dataSet, k, u, sigma, step=10):
    '''
    Fit a 1-D Gaussian mixture model to dataSet with the EM algorithm.

    dataSet: 1-D array of observations
    k, u, sigma: initial mixing coefficients / means / standard deviations
                 (updated in place)
    step: number of EM iterations
    Returns [k, u, sigma] holding the fitted parameters.
    '''
    n = len(k)
    # Bug fix: was len(dataArr), a global leaking in from the demo script.
    dataNum = len(dataSet)
    gamaArr = zeros((n, dataNum))
    for s in range(step):
        # E-step: responsibilities gamaArr[i][j] = P(component i | sample j)
        for j in range(dataNum):
            # weights for sample j (independent of i, so computed once)
            w = [k[t] * normalFun(dataSet[j], u[t], sigma[t]) for t in range(n)]
            wSum = float(sum(w))
            for i in range(n):
                gamaArr[i][j] = w[i] / wSum

        # M-step: re-estimate means, then standard deviations, then weights
        for i in range(n):
            u[i] = sum(gamaArr[i] * dataSet) / sum(gamaArr[i])
        for i in range(n):
            sigma[i] = sqrt(sum(gamaArr[i] * (dataSet - u[i]) ** 2) / sum(gamaArr[i]))
        for i in range(n):
            k[i] = sum(gamaArr[i]) / dataNum

    return [k, u, sigma]
```
109+

0 commit comments

Comments
 (0)