Skip to content

Commit cf94391

Browse files
committed
Merge pull request #141 from JuliaStats/anj/covcor
Update to changes in Base's cov and cor
2 parents a7de8c0 + 9af633a commit cf94391

File tree

2 files changed

+148
-56
lines changed

2 files changed

+148
-56
lines changed

src/cov.jl

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,60 @@ end
2525

2626
scattermat_zm(x::DenseMatrix, vardim::Int) = Base.unscaled_covzm(x, vardim)
2727

28-
scattermat_zm(x::DenseMatrix, wv::WeightVec, vardim::Int) =
28+
scattermat_zm(x::DenseMatrix, wv::WeightVec, vardim::Int) =
2929
_symmetrize!(Base.unscaled_covzm(x, _scalevars(x, values(wv), vardim), vardim))
3030

31-
function scattermat(x::DenseMatrix; mean=nothing, vardim::Int=1)
32-
mean == 0 ? scattermat_zm(x, vardim) :
33-
mean == nothing ? scattermat_zm(x .- Base.mean(x, vardim), vardim) :
34-
scattermat_zm(x .- mean, vardim)
35-
end
31+
if VERSION < v"0.5.0-dev+679"
32+
function scattermat(x::DenseMatrix; mean=nothing, vardim::Int=1)
33+
mean == 0 ? scattermat_zm(x, vardim) :
34+
mean == nothing ? scattermat_zm(x .- Base.mean(x, vardim), vardim) :
35+
scattermat_zm(x .- mean, vardim)
36+
end
3637

37-
function scattermat(x::DenseMatrix, wv::WeightVec; mean=nothing, vardim::Int=1)
38-
mean == 0 ? scattermat_zm(x, wv, vardim) :
39-
mean == nothing ? scattermat_zm(x .- Base.mean(x, wv, vardim), wv, vardim) :
40-
scattermat_zm(x .- mean, wv, vardim)
41-
end
38+
function scattermat(x::DenseMatrix, wv::WeightVec; mean=nothing, vardim::Int=1)
39+
mean == 0 ? scattermat_zm(x, wv, vardim) :
40+
mean == nothing ? scattermat_zm(x .- Base.mean(x, wv, vardim), wv, vardim) :
41+
scattermat_zm(x .- mean, wv, vardim)
42+
end
43+
44+
## weighted cov
45+
Base.cov(x::DenseMatrix, wv::WeightVec; mean=nothing, vardim::Int=1) =
46+
scale!(scattermat(x, wv; mean=mean, vardim=vardim), inv(sum(wv)))
47+
48+
function mean_and_cov(x::DenseMatrix; vardim::Int=1)
49+
m = mean(x, vardim)
50+
return m, Base.covm(x, m; vardim=vardim)
51+
end
52+
function mean_and_cov(x::DenseMatrix, wv::WeightVec; vardim::Int=1)
53+
m = mean(x, wv, vardim)
54+
return m, Base.cov(x, wv; mean=m, vardim=vardim)
55+
end
56+
else
57+
scattermatm(x::DenseMatrix, mean, vardim::Int=1) =
58+
scattermat_zm(x .- mean, vardim)
59+
60+
scattermatm(x::DenseMatrix, mean, wv::WeightVec, vardim::Int=1) =
61+
scattermat_zm(x .- mean, wv, vardim)
62+
63+
scattermat(x::DenseMatrix, vardim::Int=1) =
64+
scattermatm(x, Base.mean(x, vardim), vardim)
65+
66+
scattermat(x::DenseMatrix, wv::WeightVec, vardim::Int=1) =
67+
scattermatm(x, Base.mean(x, wv, vardim), wv, vardim)
4268

43-
## weighted cov
69+
## weighted cov
70+
Base.covm(x::DenseMatrix, mean, wv::WeightVec, vardim::Int=1) =
71+
scale!(scattermatm(x, mean, wv, vardim), inv(sum(wv)))
4472

45-
Base.cov(x::DenseMatrix, wv::WeightVec; mean=nothing, vardim::Int=1) =
46-
scale!(scattermat(x, wv; mean=mean, vardim=vardim), inv(sum(wv)))
73+
Base.cov(x::DenseMatrix, wv::WeightVec, vardim::Int=1) =
74+
Base.covm(x, Base.mean(x, wv, vardim), wv, vardim)
4775

48-
mean_and_cov(x::DenseMatrix; vardim::Int=1) = (m = mean(x, vardim); (m, Base.covm(x, m; vardim=vardim)))
49-
mean_and_cov(x::DenseMatrix, wv::WeightVec; vardim::Int=1) =
50-
(m = mean(x, wv, vardim); (m, Base.cov(x, wv; mean=m, vardim=vardim)))
76+
function mean_and_cov(x::DenseMatrix, vardim::Int=1)
77+
m = mean(x, vardim)
78+
return m, Base.covm(x, m, vardim)
79+
end
80+
function mean_and_cov(x::DenseMatrix, wv::WeightVec, vardim::Int=1)
81+
m = mean(x, wv, vardim)
82+
return m, Base.cov(x, wv, vardim)
83+
end
84+
end

test/cov.jl

Lines changed: 97 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -31,61 +31,119 @@ Sz2w = X * diagm(w2) * X'
3131

3232
## scattermat
3333

34-
@test_approx_eq scattermat(X) S1
35-
@test_approx_eq scattermat(X; vardim=2) S2
34+
if VERSION < v"0.5.0-dev+679"
35+
@test_approx_eq scattermat(X) S1
36+
@test_approx_eq scattermat(X; vardim=2) S2
3637

37-
@test_approx_eq scattermat(X; mean=0) Sz1
38-
@test_approx_eq scattermat(X; mean=0, vardim=2) Sz2
38+
@test_approx_eq scattermat(X; mean=0) Sz1
39+
@test_approx_eq scattermat(X; mean=0, vardim=2) Sz2
3940

40-
@test_approx_eq scattermat(X; mean=mean(X,1)) S1
41-
@test_approx_eq scattermat(X; mean=mean(X,2), vardim=2) S2
41+
@test_approx_eq scattermat(X; mean=mean(X,1)) S1
42+
@test_approx_eq scattermat(X; mean=mean(X,2), vardim=2) S2
4243

43-
@test_approx_eq scattermat(X; mean=zeros(1,8)) Sz1
44-
@test_approx_eq scattermat(X; mean=zeros(3), vardim=2) Sz2
44+
@test_approx_eq scattermat(X; mean=zeros(1,8)) Sz1
45+
@test_approx_eq scattermat(X; mean=zeros(3), vardim=2) Sz2
4546

46-
## weighted scatter mat
47+
## weighted scatter mat
4748

48-
@test_approx_eq scattermat(X, wv1) S1w
49-
@test_approx_eq scattermat(X, wv2; vardim=2) S2w
49+
@test_approx_eq scattermat(X, wv1) S1w
50+
@test_approx_eq scattermat(X, wv2; vardim=2) S2w
5051

51-
@test_approx_eq scattermat(X, wv1; mean=0) Sz1w
52-
@test_approx_eq scattermat(X, wv2; mean=0, vardim=2) Sz2w
52+
@test_approx_eq scattermat(X, wv1; mean=0) Sz1w
53+
@test_approx_eq scattermat(X, wv2; mean=0, vardim=2) Sz2w
5354

54-
@test_approx_eq scattermat(X, wv1; mean=mean(X, wv1, 1)) S1w
55-
@test_approx_eq scattermat(X, wv2; mean=mean(X, wv2, 2), vardim=2) S2w
55+
@test_approx_eq scattermat(X, wv1; mean=mean(X, wv1, 1)) S1w
56+
@test_approx_eq scattermat(X, wv2; mean=mean(X, wv2, 2), vardim=2) S2w
5657

57-
@test_approx_eq scattermat(X, wv1; mean=zeros(1,8)) Sz1w
58-
@test_approx_eq scattermat(X, wv2; mean=zeros(3), vardim=2) Sz2w
58+
@test_approx_eq scattermat(X, wv1; mean=zeros(1,8)) Sz1w
59+
@test_approx_eq scattermat(X, wv2; mean=zeros(3), vardim=2) Sz2w
60+
else
61+
@test_approx_eq scattermat(X) S1
62+
@test_approx_eq scattermat(X, 2) S2
5963

60-
# weighted covariance
64+
@test_approx_eq StatsBase.scattermatm(X, 0) Sz1
65+
@test_approx_eq StatsBase.scattermatm(X, 0, 2) Sz2
6166

62-
@test_approx_eq cov(X, wv1) S1w ./ sum(wv1)
63-
@test_approx_eq cov(X, wv2; vardim=2) S2w ./ sum(wv2)
67+
@test_approx_eq StatsBase.scattermatm(X, mean(X,1)) S1
68+
@test_approx_eq StatsBase.scattermatm(X, mean(X,2), 2) S2
6469

65-
@test_approx_eq cov(X, wv1; mean=0) Sz1w ./ sum(wv1)
66-
@test_approx_eq cov(X, wv2; mean=0, vardim=2) Sz2w ./ sum(wv2)
70+
@test_approx_eq StatsBase.scattermatm(X, zeros(1,8)) Sz1
71+
@test_approx_eq StatsBase.scattermatm(X, zeros(3), 2) Sz2
6772

68-
@test_approx_eq cov(X, wv1; mean=mean(X, wv1, 1)) S1w ./ sum(wv1)
69-
@test_approx_eq cov(X, wv2; mean=mean(X, wv2, 2), vardim=2) S2w ./ sum(wv2)
73+
## weighted scatter mat
7074

71-
@test_approx_eq cov(X, wv1; mean=zeros(1,8)) Sz1w ./ sum(wv1)
72-
@test_approx_eq cov(X, wv2; mean=zeros(3), vardim=2) Sz2w ./ sum(wv2)
75+
@test_approx_eq scattermat(X, wv1) S1w
76+
@test_approx_eq scattermat(X, wv2, 2) S2w
7377

74-
# mean_and_cov
78+
@test_approx_eq StatsBase.scattermatm(X, 0, wv1) Sz1w
79+
@test_approx_eq StatsBase.scattermatm(X, 0, wv2, 2) Sz2w
80+
81+
@test_approx_eq StatsBase.scattermatm(X, mean(X, wv1, 1), wv1) S1w
82+
@test_approx_eq StatsBase.scattermatm(X, mean(X, wv2, 2), wv2, 2) S2w
83+
84+
@test_approx_eq StatsBase.scattermatm(X, zeros(1,8), wv1) Sz1w
85+
@test_approx_eq StatsBase.scattermatm(X, zeros(3), wv2, 2) Sz2w
86+
end
87+
88+
# weighted covariance
89+
90+
if VERSION < v"0.5.0-dev+679"
91+
@test_approx_eq cov(X, wv1) S1w ./ sum(wv1)
92+
@test_approx_eq cov(X, wv2; vardim=2) S2w ./ sum(wv2)
93+
94+
@test_approx_eq cov(X, wv1; mean=0) Sz1w ./ sum(wv1)
95+
@test_approx_eq cov(X, wv2; mean=0, vardim=2) Sz2w ./ sum(wv2)
7596

76-
(m, C) = mean_and_cov(X; vardim=1)
77-
@test m == mean(X, 1)
78-
@test C == cov(X; vardim=1)
97+
@test_approx_eq cov(X, wv1; mean=mean(X, wv1, 1)) S1w ./ sum(wv1)
98+
@test_approx_eq cov(X, wv2; mean=mean(X, wv2, 2), vardim=2) S2w ./ sum(wv2)
7999

80-
(m, C) = mean_and_cov(X; vardim=2)
81-
@test m == mean(X, 2)
82-
@test C == cov(X; vardim=2)
100+
@test_approx_eq cov(X, wv1; mean=zeros(1,8)) Sz1w ./ sum(wv1)
101+
@test_approx_eq cov(X, wv2; mean=zeros(3), vardim=2) Sz2w ./ sum(wv2)
102+
else
103+
@test_approx_eq cov(X, wv1) S1w ./ sum(wv1)
104+
@test_approx_eq cov(X, wv2, 2) S2w ./ sum(wv2)
83105

84-
(m, C) = mean_and_cov(X, wv1; vardim=1)
85-
@test m == mean(X, wv1, 1)
86-
@test C == cov(X, wv1; vardim=1)
106+
@test_approx_eq Base.covm(X, 0, wv1) Sz1w ./ sum(wv1)
107+
@test_approx_eq Base.covm(X, 0, wv2, 2) Sz2w ./ sum(wv2)
87108

88-
(m, C) = mean_and_cov(X, wv2; vardim=2)
89-
@test m == mean(X, wv2, 2)
90-
@test C == cov(X, wv2; vardim=2)
109+
@test_approx_eq Base.covm(X, mean(X, wv1, 1), wv1) S1w ./ sum(wv1)
110+
@test_approx_eq Base.covm(X, mean(X, wv2, 2), wv2, 2) S2w ./ sum(wv2)
91111

112+
@test_approx_eq Base.covm(X, zeros(1,8), wv1) Sz1w ./ sum(wv1)
113+
@test_approx_eq Base.covm(X, zeros(3), wv2, 2) Sz2w ./ sum(wv2)
114+
end
115+
116+
# mean_and_cov
117+
if VERSION < v"0.5.0-dev+679"
118+
(m, C) = mean_and_cov(X; vardim=1)
119+
@test m == mean(X, 1)
120+
@test C == cov(X; vardim=1)
121+
122+
(m, C) = mean_and_cov(X; vardim=2)
123+
@test m == mean(X, 2)
124+
@test C == cov(X; vardim=2)
125+
126+
(m, C) = mean_and_cov(X, wv1; vardim=1)
127+
@test m == mean(X, wv1, 1)
128+
@test C == cov(X, wv1; vardim=1)
129+
130+
(m, C) = mean_and_cov(X, wv2; vardim=2)
131+
@test m == mean(X, wv2, 2)
132+
@test C == cov(X, wv2; vardim=2)
133+
else
134+
(m, C) = mean_and_cov(X, 1)
135+
@test m == mean(X, 1)
136+
@test C == cov(X, 1)
137+
138+
(m, C) = mean_and_cov(X, 2)
139+
@test m == mean(X, 2)
140+
@test C == cov(X, 2)
141+
142+
(m, C) = mean_and_cov(X, wv1, 1)
143+
@test m == mean(X, wv1, 1)
144+
@test C == cov(X, wv1, 1)
145+
146+
(m, C) = mean_and_cov(X, wv2, 2)
147+
@test m == mean(X, wv2, 2)
148+
@test C == cov(X, wv2, 2)
149+
end

0 commit comments

Comments
 (0)