1
2
3
4
5
6
7
8
9
10
11
12
13 package org.astrogrid.cluster.cluster;
14
15 import no.uib.cipr.matrix.AGDenseMatrix;
16 import no.uib.cipr.matrix.DenseVector;
17 import no.uib.cipr.matrix.MatrixEntry;
18 import no.uib.cipr.matrix.Vector;
19
20 import org.astrogrid.matrix.Matrix;
21 import static org.astrogrid.matrix.MatrixUtils.*;
22 import static org.astrogrid.matrix.Algorithms.*;
23 import static java.lang.Math.*;
24
25
26
27
28
29
30
31
32 public class ClusterEStepFull {
33
34
35
36
37
38
39
40
41
42 public static class Retval {
43 public final AGDenseMatrix output;
44 public final AGDenseMatrix ab;
45 public final AGDenseMatrix q;
46 public final AGDenseMatrix C;
47
48 public Retval(AGDenseMatrix output, AGDenseMatrix ab, AGDenseMatrix q, AGDenseMatrix C) {
49 this.output = output;
50 this.ab = ab;
51 this.q = q;
52 this.C = C;
53 }
54 }
55
56 public static Retval clustering_e_step_full(Matrix data, Matrix datatype, int K, Matrix latent,
57 AGDenseMatrix ab, Matrix mu, Matrix cv, Matrix lmu, Matrix lcv, Vector p, CovarianceKind cv_type
58
59 ){
60
61
62
63 int ndata = data.numRows();
64 int no_of_data_types = datatype.numRows();
65 int outpos = 0;
66 int abpos= 0;
67
68 int n0 = 0, n1 = 0, nm = 0, ne = 0, d = 0;
69 Matrix data_nr = null, data_er = null, data_bin = null, data_mul = null, data_int = null;
70 Matrix gmu_nr = null,gcv_nr_f[] = new AGDenseMatrix[K], gcv_nr_d = null, gmu = null, gcv_f[] = new AGDenseMatrix[K], gcv_d = null, S = null;
71 Vector gcv_nr_c= null, gcv_c = null;
72 Matrix bp = null, mp = null, ip = null;
73 AGDenseMatrix C = new AGDenseMatrix(ndata, K);
74
75 int ndim_nr = 0, ndim_er = 0,ndim_bin=0,ndim_mul = 0,ndim_int = 0;
76 for (int i = 0; i < no_of_data_types; i++){
77 if (datatype.get(i,0) == 1){
78 if((ndim_nr = (int) datatype.get(i,1)) > 0){
79 data_nr = data.sliceCol(d,ndim_nr );
80 gmu_nr = reshape(mu.asVector(nm,nm+K*ndim_nr-1), K, ndim_nr);
81 nm = nm + K*ndim_nr;
82 gcv_nr_d = new AGDenseMatrix(K,ndim_nr);
83 switch (cv_type){
84 case free:
85 for (int k = 0; k < K; k++){
86 gcv_nr_f[k] = reshape(cv.asVector(n0,n0+ndim_nr*ndim_nr -1),
87 ndim_nr, ndim_nr);
88 n0 = n0 + ndim_nr *ndim_nr;
89 }
90 break;
91 case diagonal:
92 for( int k=0; k < K; k++){
93 gcv_nr_d.setRow(k, reshape(cv.asVector(n0,n0+ndim_nr -1), 1, ndim_nr));
94 n0 = n0 + ndim_nr;
95 }
96 break;
97 case common:
98 gcv_nr_c = cv.asVector(n0);
99 n0 = n0 + K;
100 break;
101 }
102 }
103 d = d + ndim_nr;
104 }
105 else if (datatype.get(i,0) == 2){
106 if((ndim_er = (int) datatype.get(i,1))!=0){
107 data_er = data.sliceCol(d,ndim_er);
108 gmu = reshape(lmu.asVector(ne,ne+K*ndim_er-1),K, ndim_er);
109 ne = ne + K*ndim_er;
110 gcv_d = new AGDenseMatrix(K,ndim_er);
111 switch (cv_type) {
112 case free:
113 for (int k = 0; k < K ; k++){
114 gcv_f[k] = reshape(lcv.asVector(n1,n1+ndim_er*ndim_er-1),
115 ndim_er,ndim_er);
116 n1 = n1 + ndim_er * ndim_er;
117 }
118 break;
119 case diagonal:
120 for(int k=0; k < K; k++){
121 gcv_d.setRow(k,reshape(lcv.asVector(n1,n1+ndim_er-1), 1, ndim_er));
122 n1 = n1 + ndim_er;
123 }
124 break;
125 case common:
126 gcv_c = lcv.asVector(n1);
127 break;
128 }
129 d = d + ndim_er;
130 }
131 }
132 else if (datatype.get(i,0) == 3){
133 ndim_bin = (int) datatype.get(i,1);
134 data_bin = data.sliceCol(d,ndim_bin);
135 bp = reshape(mu.asVector(nm,nm+K*ndim_bin-1), K, ndim_bin);
136 nm = nm + K*ndim_bin;
137 d = d + ndim_bin;
138 }
139 else if (datatype.get(i,0) == 4){
140 ndim_mul = (int) datatype.get(i,1);
141 data_mul = data.sliceCol(d, ndim_mul );
142 mp = reshape(mu.asVector(nm,nm+K*ndim_mul -1), K, ndim_mul);
143 nm = nm + K*ndim_mul;
144 d = d + ndim_mul;
145 }
146 else if (datatype.get(i,0) == 5){
147 ndim_int = (int) datatype.get(i,1);
148 data_int = data.sliceCol(d,ndim_int);
149 ip = reshape(mu.asVector(nm,nm+K*ndim_int-1), K, ndim_int);
150 nm = nm + K*ndim_int;
151 d= d + ndim_int;
152 }
153 else if (datatype.get(i,0) == 6){
154 int ndim_error = (int) datatype.get(i,1);
155 if (ndim_error != ndim_er){
156 throw new IllegalArgumentException( "The dimension of measurement errors and ");
157 }
158 S = data.sliceCol(d,ndim_error);
159
160 S.add(ones(ndata, ndim_er).scale(1.0e-6));
161 d = d + ndim_error;
162 }
163 }
164
165 AGDenseMatrix a = null, b = null;
166 Vector v = null;
167 if (ndim_er != 0 || ndim_nr != 0){
168 a = reshape(ab.asVector(0,ndata*K-1), ndata, K);
169 b = reshape(ab.asVector(ndata*K), ndata, K);
170 v = reshape(latent, K, 1).asVector();
171 }
172
173
174
175
176
177
178 int ndim= data.numColumns();
179 AGDenseMatrix output = new AGDenseMatrix(ndim_er*(ndim_er+1)*K*ndata,1);
180
181 AGDenseMatrix qcv[][]= new AGDenseMatrix[ndata][K];
182 DenseVector qmu[][]= new DenseVector[ndata][K];
183 if (ndim_er != 0) {
184
185 Matrix epu = divide(a , b);
186
187
188 for (int n=0 ; n <ndata; n++){
189 for (int k =0; k <K; k++){
190 switch (cv_type){
191 case free:
192 Matrix gcv_fe = new AGDenseMatrix(gcv_f[k]);
193 gcv_fe.scale(1.0/epu.get(n,k));
194 qcv[n][k] = mult(mult(diag(S.sliceRow(n)),inv(add(diag(S.sliceRow(n)),gcv_fe))),gcv_fe);
195 qmu[n][k] = add(mult(gmu.sliceRowM(k),mult(diag(S.sliceRow(n)),inv(add(diag(S.sliceRow(n)),gcv_fe))))
196
197 ,mult(data_er.sliceRowM(n),mult(gcv_fe,inv(add(diag(S.sliceRow(n)),gcv_fe))))).asVector();
198 outpos = output.insert( qcv[n][k], outpos);
199 outpos = output.insert(qmu[n][k], outpos);
200
201 break;
202 case diagonal:
203 Matrix gcv_scale = (Matrix) diag(gcv_d.sliceRow(k)).scale(1.0/epu.get(n,k));
204 qcv[n][k]=mult(mult(diag(S.sliceRow(n)),inv(add(gcv_scale,
205 diag(S.sliceRow(n))))),gcv_scale);
206 qmu[n][k]=add(multBt(gmu.sliceRowM(k),mult(diag(S.sliceRow(n)),inv(add(gcv_scale,diag(S.sliceRow(n))))))
207 ,multBt(data_er.sliceRowM(n),mult(gcv_scale,inv(add(gcv_scale,diag(S.sliceRow(n))))))).asVector();
208 outpos = output.insert( qcv[n][k], outpos);
209 outpos = output.insert(qmu[n][k], outpos);
210 break;
211 case common:
212 no.uib.cipr.matrix.Matrix tcv = eye(ndim_er).scale(gcv_c.get(k)/epu.get(n,k));
213 qcv[n][k]=mult(diag(S.sliceRow(n)).scale((gcv_c.get(k))/epu.get(n,k)),inv(
214 diag(S.sliceRow(n)).add(gcv_c.get(k)/epu.get(n,k))));
215
216
217 qmu[n][k]=add(multBt(gmu.sliceRowM(k),mult(diag(S.sliceRow(n)),inv(add(tcv ,diag(S.sliceRow(n))))))
218 ,mult(multBt(data_er.sliceRowM(n),tcv),inv(add(tcv,diag(S.sliceRow(n)))))).asVector();
219 outpos = output.insert( qcv[n][k], outpos);
220 outpos = output.insert(qmu[n][k], outpos);
221 break;
222 }
223 }
224 }
225 Matrix epdw = new AGDenseMatrix(ndata, K);
226
227 for (int n = 0; n < ndata; n++){
228 for (int k = 0; k < K; k++){
229 switch (cv_type){
230 case free:
231 epdw.set(n,k, trace(mult(inv(gcv_f[k]),qcv[n][k])) +
232 multATBA(qmu[n][k],inv(gcv_f[k])).asScalar()
233 );
234
235
236 C.set(n,k, epdw.get(n,k)-2*multBt(multAt(qmu[n][k],inv(gcv_f[k]))
237 ,gmu.sliceRowM(k)).get(0, 0)+multABAT(gmu.sliceRowM(k),inv(gcv_f[k])).get(0,0));
238 break;
239 case diagonal:
240 epdw.set(n,k,trace(mult(inv(diag(gcv_d.sliceRow(k))),qcv[n][k]))
241 +multATBA(qmu[n][k],inv(diag(gcv_d.sliceRow(k)))).asScalar());
242
243 C.set(n,k,epdw.get(n,k)-2*multBt(multAt(qmu[n][k],inv(diag(gcv_d.sliceRow(k))))
244 ,gmu.sliceRowM(k)).get(0, 0)+multABAT(gmu.sliceRowM(k),inv(diag(gcv_d.sliceRow(k)))).get(0, 0));
245 break;
246 case common:
247 Vector tqmu = new DenseVector(qmu[n][k]);
248 tqmu.add(-1, gmu.sliceRow(k));
249
250 C.set(n,k, tqmu.dot(tqmu)/gcv_c.get(k)+
251 trace(qcv[n][k])/gcv_c.get(k));
252
253
254
255
256
257 break;
258 }
259 }
260 }
261 }
262 Matrix n3 = null;
263 if (ndim_nr != 0){
264
265
266
267
268
269 switch (cv_type){
270 case free:
271
272 n3 = dist3_free(data_nr, gmu_nr, gcv_nr_f);
273
274 break;
275
276 case common:
277 n3 = dist3_common(data_nr, gmu_nr, gcv_nr_c);
278 break;
279
280 case diagonal:
281 n3 = dist3_diag(data_nr, gmu_nr, gcv_nr_d);
282 break;
283
284
285 }
286 }
287
288
289
290 if (datatype.get(0,1) !=0 && datatype.get(1,1) == 0 ){
291
292 a = (AGDenseMatrix) (repmatt(v, ndata, 1).add(ndim_nr).scale(0.5));
293 b = (AGDenseMatrix) (repmatt(v, ndata, 1).add(n3).scale(0.5));
294 C = (AGDenseMatrix) n3;
295 abpos = ab.insert(a,abpos);
296 abpos = ab.insert(b,abpos);
297
298
299 }
300 else if( datatype.get(1,1) !=0 && datatype.get(0,1) == 0){
301
302 a = (AGDenseMatrix) (repmatt(v, ndata, 1).add(ndim_er).scale(0.5));
303 b = (AGDenseMatrix) (repmatt(v, ndata, 1).add(C)).scale(0.5);
304 abpos = ab.insert(a,abpos);
305 abpos = ab.insert(b,abpos);
306
307 }
308 else if (datatype.get(0,1)!=0 && datatype.get(1,1) !=0){
309
310 a = (AGDenseMatrix) (repmatt(v, ndata, 1).add(ndim_nr+ndim_er).scale(0.5));
311 b = (AGDenseMatrix) (repmatt(v, ndata, 1).add(C).add(n3).scale(0.5));
312 abpos = ab.insert(a,abpos);
313 abpos = ab.insert(b,abpos);
314 C.add(n3);
315 }
316
317 Matrix aux1 = new AGDenseMatrix(ndata,K), aux2 = new AGDenseMatrix(ndata,K), aux3 = new AGDenseMatrix(ndata,K), aux4 = new AGDenseMatrix(ndata,K),
318 aux5 = new AGDenseMatrix(ndata,K), aux6 = new AGDenseMatrix(ndata,K),
319 aux7 = new AGDenseMatrix(ndata,K), aux8 = new AGDenseMatrix(ndata,K), aux9 = new AGDenseMatrix(ndata,K), aux10 = new AGDenseMatrix(ndata,K),
320 aux11 = new AGDenseMatrix(ndata,K);
321 aux1 = zeros(ndata, K);
322 aux2 = zeros(ndata, K);
323 aux3 = zeros(ndata, K);
324 aux4 = zeros(ndata, K);
325 aux5 = zeros(ndata, K);
326 aux6 = zeros(ndata, K);
327 aux9 = zeros(ndata, K);
328 aux10 = zeros(ndata, K);
329 aux11 = zeros(ndata, K);
330
331
332
333
334 if (ndim_er != 0){
335
336 for (int n = 0; n < ndata; n++){
337 for (int k = 0; k <K; k++){
338 switch (cv_type){
339 case free:
340 aux1.set(n,k, -ndim_er/2.0*log(2*PI)-log(det(diag(S.sliceRow(n))))/2.0-
341 (multBt(mult(data_er.sliceRowM(n),inv(diag(S.sliceRow(n)))),data_er.sliceRowM(n)).asScalar() -
342 multBt(multAt(qmu[n][k],inv(diag(S.sliceRow(n)))),data_er.sliceRowM(n)).asScalar()*2 +
343 trace(mult(inv(diag(S.sliceRow(n))),qcv[n][k]))
344 + multATBA(qmu[n][k],inv(diag(S.sliceRow(n)))).asScalar()
345 )/2.0);
346 aux2.set(n,k, -ndim_er/2.0*log(2*Math.PI)-log(det(gcv_f[k]
347 ))/2.0+ndim_er/2.0*(psi(a.get(n,k))-log(b.get(n,k)))-
348 0.5*a.get(n,k)/b.get(n,k)*(
349 multATBA(qmu[n][k],inv(gcv_f[k])).asScalar()+ (trace(mult(inv(gcv_f[k]),qcv[n][k])))
350 -2*multBt(multAt(qmu[n][k],inv(gcv_f[k])),gmu.sliceRowM(k)).asScalar()
351 + multABAT(gmu.sliceRowM(k),inv(gcv_f[k])).asScalar() ));
352
353 aux4.set(n,k, ndim_er/2.0*log(2*Math.PI)+log(det(qcv[n][k]))/2.0+ndim_er/2.0);
354 break;
355 case diagonal:
356 aux1.set(n,k, -ndim_er/2.0*log(2*Math.PI)-log(det(diag(S.sliceRow(n))))/2.0-
357 (multABAT(data_er.sliceRowM(n),inv(diag(S.sliceRow(n)))).asScalar() -
358 2*multBt(multAt(qmu[n][k],inv(diag(S.sliceRow(n)))),data_er.sliceRowM(n)).asScalar() +
359 trace(mult(inv(diag(S.sliceRow(n))),qcv[n][k]))
360 + multATBA(qmu[n][k],inv(diag(S.sliceRow(n)))).asScalar() )/2.0);
361 Vector covk = gcv_d.sliceRow(k);
362 aux2.set(n,k, -ndim_er/2.0*log(2*Math.PI)-sum(log(covk))/2.0 +
363 ndim_er/2.0*(psi(a.get(n,k))-log(b.get(n,k)))-
364 0.5*a.get(n,k)/b.get(n,k)*C.get(n,k));
365
366 aux4.set(n,k, ndim_er/2.0*log(2*Math.PI)+log(det(qcv[n][k]))/2.0+ndim_er/2.0);
367 break;
368 case common:
369 aux1.set(n,k, -ndim_er/2.0*log(2*Math.PI)-log(det(diag(S.sliceRow(n))))/2.0-
370 (multABAT(data_er.sliceRowM(n),inv(diag(S.sliceRow(n)))).asScalar() -
371 2*multBt(multAt(qmu[n][k],inv(diag(S.sliceRow(n)))),data_er.sliceRowM(n)).asScalar() +
372 trace(mult(inv(diag(S.sliceRow(n))),qcv[n][k]))
373 + multATBA(qmu[n][k],inv(diag(S.sliceRow(n)))).asScalar()
374 )/2.0);
375 double covkd = gcv_c.get(k);
376 aux2.set(n,k, -ndim_er/2.0*log(2*Math.PI)-ndim_er*log(covkd)/2.0 +
377 ndim_er/2.0*(psi(a.get(n,k))-log(b.get(n,k)))-
378 0.5*a.get(n,k)/b.get(n,k)*C.get(n,k));
379
380 aux4.set(n,k, ndim_er/2.0*log(2*Math.PI)+log(det(qcv[n][k]))/2.0+ndim_er/2.0);
381 break;
382 default:
383 throw new IllegalArgumentException( "Unknown noise/covariance model");
384 }
385
386 aux3.set(n,k, v.get(k)/2.0*log(v.get(k)/2.0)+(v.get(k)/2.0-1)*(psi(a.get(n,k))-log(b.get(n,k)))
387 -v.get(k)/2.0*a.get(n,k)/b.get(n,k)-log(gamma(v.get(k)/2.0)));
388
389 aux5.set(n,k, -((a.get(n,k)-1)*psi(a.get(n,k))+log(b.get(n,k))-a.get(n,k)-
390 log(gamma(a.get(n,k)))));
391 }
392 }
393 }
394 if (ndim_nr != 0){
395
396
397 for(int n = 0; n <ndata; n++){
398 for (int k = 0; k <K; k++){
399
400 if (ndim_er != 0){
401
402 switch (cv_type){
403 case diagonal:
404 Vector covkd = gcv_nr_d.sliceRow(k);
405 aux6.set(n,k, -ndim_nr/2.0*log(2*Math.PI)-sum(log(covkd))/2.0+
406 ndim_nr/2.0*(psi(a.get(n,k))-log(b.get(n,k)))-
407 0.5*a.get(n,k)/b.get(n,k)*n3.get(n,k));
408 break;
409 case common:
410 double covkc = gcv_nr_c.get(k);
411 aux6.set(n,k, -ndim_nr/2.0*log(2*Math.PI)-ndim_nr*log(covkc)/2.0+
412 ndim_nr/2.0*(psi(a.get(n,k))-log(b.get(n,k)))-
413 0.5*a.get(n,k)/b.get(n,k)*n3.get(n,k));
414 break;
415 case free:
416 aux6.set(n,k, -ndim_nr/2.0*log(2*Math.PI)-log(det(
417 gcv_nr_f[k]))/2.0+ndim_nr/2.0*(psi(a.get(n,k))
418 -log(b.get(n,k)))-0.5*a.get(n,k)/b.get(n,k)*n3.get(n,k));
419 break;
420 default:
421 throw new IllegalArgumentException( "Unknown noise model");
422 }
423 }
424 if (ndim_er == 0){
425
426 switch (cv_type){
427 case common:
428
429 double covk = gcv_nr_c.get(k);
430 aux6.set(n,k, log(gamma((v.get(k)+ndim_nr)/2.0))-
431 ndim_nr/2.0*log(covk)-ndim_nr/2.0*log(Math.PI*v.get(k))
432 -log(gamma(v.get(k)/2.0))-(v.get(k) + ndim_nr)/2.0*
433 log(1+n3.get(n,k)/(covk*v.get(k))));
434 break;
435 case diagonal:
436
437 Vector covkd = gcv_nr_d.sliceRow(k);
438
439 double dist = sum(divide(pow(sub(data_nr.sliceRowM(n),gmu_nr.sliceRowM(k)),2.0).asVector() , covkd));
440 aux6.set(n,k, log(gamma((v.get(k)+ndim_nr)/2.0))-
441 sum(log(covkd))/2.0-
442 ndim_nr/2.0*log(Math.PI * v.get(k))-log(gamma(v.get(k)/2.0))-
443 (v.get(k) + ndim_nr)/2.0*log(1 + dist/v.get(k)));
444 break;
445 case free:
446 Matrix covkf = gcv_nr_f[k];
447 aux6.set(n,k, log(gamma((v.get(k)+ndim_nr)/2.0))-log(det(
448 covkf))/2.0-ndim_nr/2.0*log(Math.PI*v.get(k))-log(gamma(v.get(k)/2.0))-
449 ((v.get(k)+ndim_nr)/2.0)*log(1+n3.get(n,k)/v.get(k)));
450 break;
451 }
452 }
453 }
454 }
455 }
456
457
458 Vector sgm;
459 Vector datai;
460 if (ndim_bin !=0){
461
462
463
464 for (int i = 0 ; i < ndata; i++){
465 for (int j = 0 ; j < K; j++){
466 sgm = add(bp.sliceRow(j), eps);
467 datai = data_bin.sliceRow(i);
468 aux9.set(i,j, sum(add(times(datai,log(sgm)),times(sub(1.0,datai),log(sub(1.0,sgm))))));
469 }
470 }
471 }
472
473 if (ndim_mul != 0){
474
475
476 for(int n = 0 ; n < ndata; n++){
477 for(int k = 0 ; k< K; k++){
478 Vector temp = add(mp.sliceRow(k),eps);
479 Vector log_sgm = log(temp);
480 Vector datat = data_mul.sliceRow(n);
481 aux10.set(n,k, sum(times(datat,log_sgm) ));
482 }
483 }
484 }
485
486 if (ndim_int !=0){
487
488 for(int t=0 ; t <ndata; t++){
489 for (int i = 0 ; i < K; i++){
490 sgm = ip.sliceRow(i);
491 datai = data_int.sliceRow(t);
492 double poiss = sum(sub(times(datai,log(sgm)),sgm));
493 aux11.set(t,i, poiss-sum(log(seq(ndim_int))));
494 }
495 }
496 }
497
498 Matrix aux = add(add(add(add(add(add(add(add(aux1 , aux2) ,aux3), aux4), aux5), aux6), aux9), aux10) , aux11);
499
500 AGDenseMatrix q = (AGDenseMatrix) times(multBt(ones(ndata, 1),new AGDenseMatrix(p)),exp(aux));
501 Matrix s = new AGDenseMatrix(sum(q, 2));
502
503 for (MatrixEntry matrixEntry : s) {
504 if(matrixEntry.get() == 0.0)
505 {
506 matrixEntry.set(1.0);
507 }
508 }
509 q = (AGDenseMatrix) divide(q,mult(s,ones(1, K)));
510
511 Retval retval = new Retval(output, ab, q, C);
512
513 return retval;
514
515 }
516
517
518
519
520
521
522 }
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544