1
2
3
4
5
6
7
8
9
10
11
12
13 package org.astrogrid.cluster.cluster;
14 import no.uib.cipr.matrix.AGDenseMatrix;
15 import no.uib.cipr.matrix.DenseVector;
16 import no.uib.cipr.matrix.Vector;
17
18 import org.astrogrid.matrix.Matrix;
19 import static org.astrogrid.matrix.MatrixUtils.*;
20 import static org.astrogrid.matrix.Algorithms.*;
21 import static java.lang.Math.*;
22
23 public class ClusterMStep {
24 public static class Retval {
25 public final AGDenseMatrix mu;
26 public final AGDenseMatrix cv;
27 public final AGDenseMatrix lmu;
28 public final AGDenseMatrix lcv;
29 public final Vector p;
30 public Retval(AGDenseMatrix mu,AGDenseMatrix cv, AGDenseMatrix lmu, AGDenseMatrix lcv, Vector p ) {
31
32 this.mu = mu;
33 this.cv = cv;
34 this.lmu = lmu;
35 this.lcv = lcv;
36 this.p = p;
37 }
38
39 }
40
41
42
43
44
45 static Retval clustering_m_step(Matrix alldata, Matrix datatype, int K, AGDenseMatrix q,
46 AGDenseMatrix lcv, CovarianceKind cv_type){
47
48
49
50
51 int ndata = alldata.numRows();
52 int no_of_data_types = datatype.numRows();
53 Matrix data_nr = null, data_bin = null, data_mul = null, data_int = null, data_er = null;
54 Vector gcv_c = new DenseVector(no_of_data_types);
55 Matrix S = null;
56 int ndim_nr = 0, ndim_er = 0,ndim_bin = 0,ndim_mul = 0,ndim_int = 0;
57
58 int n1 = 0, d = 0;
59 Matrix qcv[][] = new AGDenseMatrix[ndata][K];
60 DenseVector qmu[][] = new DenseVector[ndata][K];
61 Matrix gmu_nr = null,gcv_nr[] = new AGDenseMatrix[K], gcv_nr_d = null, gmu = null, gcv_f[] = new AGDenseMatrix[K], gcv_d = null;
62 for ( int i = 0; i < no_of_data_types; i++){
63 if(datatype.get(i,0) == 1 ){
64 if((ndim_nr = (int)datatype.get(i,1)) != 0){
65 data_nr = alldata.sliceCol(d,ndim_nr);
66 gcv_nr_d = new AGDenseMatrix(K,ndim_nr);
67 d = d + ndim_nr;
68 }
69 }
70 else if(datatype.get(i,0) == 2 ){
71 if((ndim_er = (int)datatype.get(i,1)) != 0){
72 data_er = alldata.sliceCol(d,ndim_er);
73 gcv_d = new AGDenseMatrix(K, ndim_er);
74 switch(cv_type) {
75
76 case free:
77 for ( int k = 0; k < K; k++){
78 gcv_f[k] = reshape(lcv.asVector(n1,n1+ndim_er*ndim_er-1),
79 ndim_er,ndim_er);
80 n1 = n1 + ndim_er*ndim_er;
81 }
82 break;
83 case diagonal:
84 for ( int k = 0; k < K; k++){
85 gcv_d.setRow(k, reshape(lcv.asVector(n1,n1+ndim_er-1), 1, ndim_er));
86 n1 = n1 + ndim_er;
87 }
88 break;
89 case common:
90 gcv_c = lcv.asVector(n1);
91 }
92 d = d + ndim_er;
93 }
94 }
95 else if (datatype.get(i,0) == 3) {
96 if((ndim_bin = (int)datatype.get(i,1)) != 0){
97 data_bin = alldata.sliceCol(d,ndim_bin);
98 d = d + ndim_bin;
99 }
100 }
101 else if (datatype.get(i,0) == 4) {
102 if((ndim_mul =(int) datatype.get(i,1)) != 0){
103 data_mul = alldata.sliceCol(d,ndim_mul);
104 d = d + ndim_mul;
105 }
106 }
107 else if (datatype.get(i,0) == 5) {
108 if((ndim_int = (int)datatype.get(i,1)) != 0){
109 data_int = alldata.sliceCol(d,ndim_int);
110 d = d + ndim_int;
111 }
112 }
113 else if (datatype.get(i,0) == 6) {
114 int ndim_error =(int) datatype.get(i,1);
115 if(ndim_error != ndim_er){
116 throw new IllegalArgumentException( "The dimension of measurement errors and ");
117 }
118 S = alldata.sliceCol(d,ndim_error);
119
120 S.add(1.0e-6, ones(ndata, ndim_er));
121 d = d + ndim_error;
122 }
123 }
124
125
126
127
128
129 AGDenseMatrix mu = new AGDenseMatrix(0,0) ,cv = new AGDenseMatrix(0,0), lmu = new AGDenseMatrix(0,0);
130 lcv = new AGDenseMatrix(0,0);
131 gmu = new AGDenseMatrix(K, ndim_er);
132 if(ndim_er != 0 ){
133 for ( int k = 0; k < K; k++){
134 AGDenseMatrix tmpg = zeros(ndim_er,1);
135 AGDenseMatrix tmpt = zeros(ndim_er, ndim_er);
136 for ( int n = 0; n < ndata; n++){
137 switch(cv_type) {
138 case free:
139 tmpg.add(q.get(n,k),multBt(inv(add((gcv_f[k]),diag(
140 S.sliceRow(n)))),data_er.sliceRowM(n)));
141 tmpt.add(q.get(n,k),inv(add((gcv_f[k]),diag(S.sliceRow(n)))));
142 break;
143 case diagonal:
144 tmpg.add( q.get(n,k),multBt(inv(add(diag(gcv_d.sliceRow(k)),diag(S.sliceRow(n))))
145 ,data_er.sliceRowM(n)));
146 tmpt.add(q.get(n,k),inv(add(diag(gcv_d.sliceRow(k)),diag(S.sliceRow(n)))));
147 break;
148 case common:
149 tmpg.add(q.get(n,k),multBt(inv(add(gcv_c.get(k),diag(S.sliceRow(n)))),data_er.sliceRowM(n)));
150 tmpt.add(q.get(n,k),inv(add(gcv_c.get(k),diag(S.sliceRow(n)))));
151 break;
152 }
153 }
154 gmu.setRow(k , mult(inv(tmpt),tmpg).asVector());
155 }
156 lmu.append( gmu.asVector());
157
158
159
160 for ( int k = 0; k < K; k++){
161 switch(cv_type) {
162
163 case free:
164 Matrix cvk = pow((gcv_f[k]),0.5);
165 Vector tmpcv = Minimize.minimize("objectiveFull", 10, cvk.asVector(), q.sliceCol(k),
166 gmu.sliceRow(k), data_er, S);
167 Matrix tmpcvm = pow(reshape(tmpcv,ndim_er,ndim_er),2.0);
168 lcv.append(tmpcvm.asVector());
169 break;
170 case diagonal:
171
172 Vector cvkv = pow(gcv_d.sliceRow(k),0.5);
173 tmpcv =Minimize.minimize("objectiveDiag",10,cvkv,q.sliceCol(k),gmu.sliceRow(k),
174 data_er, S);
175 lcv.append(pow(tmpcv,2));
176 break;
177
178 case common:
179 Vector cvk_c = new DenseVector(new double[]{gcv_c.get(k)});
180 tmpcv = Minimize.minimize("objectiveSpherical",10,cvk_c,q.sliceCol(k), gmu.sliceRow(k),
181 data_er, S);
182 gcv_c.set(k , pow(tmpcv.get(0),2.0));
183 lcv.append(gcv_c.get(k));
184 }
185 }
186 }
187
188 if(ndim_nr != 0){
189
190
191
192
193
194 gmu_nr = new AGDenseMatrix(K,ndim_nr);
195 for ( int k = 0; k < K; k++){
196 Vector tmp = new DenseVector(ndim_nr).zero();
197 for ( int n = 0; n < ndata; n++){
198 tmp.add(q.get(n,k),data_nr.sliceRow(n));
199 }
200 gmu_nr.setRow(k, tmp.scale(1.0/sum(q.sliceCol(k))));
201 }
202 mu.append( gmu_nr.asVector());
203
204
205
206 for ( int k = 0; k < K; k++){
207 switch(cv_type) {
208
209
210 case free:
211 Matrix gcvt = zeros(ndim_nr, ndim_nr);
212 for ( int n = 0; n < ndata; n++){
213 Vector tmp = sub(gmu_nr.sliceRow(k),data_nr.sliceRow(n));
214
215 gcvt.add(q.get(n,k),vprod(tmp));
216 }
217 gcv_nr[k] = (Matrix) gcvt.scale(1.0 / sum(add(q.sliceCol(k),eps)));
218
219 cv.append(gcv_nr[k].asVector());
220 break;
221 case diagonal:
222 for ( int i = 0; i < ndim_nr; i++){
223 double tmp = 0.0;
224 for ( int n = 0; n < ndata; n++){
225 tmp = tmp + q.get(n,k)*pow(data_nr.get(n,i)-gmu_nr.get(k,i),2.0);
226 }
227 gcv_nr_d.set(k,i, tmp/ sum(add(q.sliceCol(k),eps)));
228 }
229 cv.append(gcv_nr_d.sliceRow(k));
230 break;
231 case common:
232 double tmpt = 0.0;
233 for ( int n = 0; n < ndata; n++){
234
235 Vector tmp = sub(data_nr.sliceRow(n), gmu_nr.sliceRow(k));
236 tmpt=tmpt+q.get(n,k)*tmp.dot(tmp);
237 }
238 tmpt = tmpt/(ndim_nr*sum(add(q.sliceCol(k),eps)));
239 if(tmpt < eps){
240 tmpt = eps;
241 }
242 cv.append(tmpt);
243
244 }
245 }
246 }
247
248
249
250
251
252 if(ndim_bin != 0){
253
254 int bin_dim = data_bin.numColumns();
255 Matrix bp = new AGDenseMatrix(K,bin_dim);
256 ndata = data_bin.numRows();
257 for ( int k = 0; k < K; k++){
258 for ( int j = 0; j < ndim_bin; j++){
259
260
261 bp.set(k,j, sum(times(q.sliceCol(k),data_bin.sliceCol(j)))/(sum(add(q.sliceCol(k),eps))));
262 }
263 }
264 mu.append(bp.asVector());
265 }
266
267
268 if(ndim_mul != 0){
269 int dim_mul = data_mul.numColumns();
270 ndata = data_mul.numRows();
271 Matrix mp = new AGDenseMatrix(K,dim_mul);
272 for ( int k = 0; k < K; k++){
273 for ( int i = 0; i < dim_mul; i++){
274 mp.set(k,i, sum(times(q.sliceCol(k),data_mul.sliceCol(i))));
275 }
276 mp.setRow(k, mp.sliceRow(k).scale(1.0/sum(mp.sliceRow(k))));
277 }
278 mu.append(mp.asVector());
279 }
280
281 if(ndim_int != 0){
282 int dim_int = data_int.numColumns();
283 Matrix ip = new AGDenseMatrix(K,dim_int);
284 for ( int k = 0; k < K; k++){
285 for ( int i = 0; i < dim_int; i++){
286 ip.set(k,i , sum(times(q.sliceCol(k),data_int.sliceCol(i)))/sum(q.sliceCol(k)));
287 }
288 }
289 mu.append(ip.asVector());
290 }
291
292
293
294
295 Vector p = mean(q, 1);
296
297
298 return new Retval(mu, cv, lmu, lcv, p);
299
300 }
301
302 }
303
304
305
306
307
308
309
310
311
312
313
314
315
316