View Javadoc

1   /*
2    * $Id: ClusterEStepFull.java,v 1.6 2010/01/11 21:22:46 pah Exp $
3    * 
4    * Created on 11 Dec 2008 by Paul Harrison (paul.harrison@manchester.ac.uk)
5    * Copyright 2008 Astrogrid. All rights reserved.
6    *
7    * This software is published under the terms of the Astrogrid 
8    * Software License, a copy of which has been included 
9    * with this distribution in the LICENSE.txt file.  
10   *
11   */ 
12  
13  package org.astrogrid.cluster.cluster;
14  
15  import no.uib.cipr.matrix.AGDenseMatrix;
16  import no.uib.cipr.matrix.DenseVector;
17  import no.uib.cipr.matrix.MatrixEntry;
18  import no.uib.cipr.matrix.Vector;
19  
20  import org.astrogrid.matrix.Matrix;
21  import static org.astrogrid.matrix.MatrixUtils.*;
22  import static org.astrogrid.matrix.Algorithms.*;
23  import static java.lang.Math.*;
24  
25  
26  
27  /**
28   * @author Paul Harrison (paul.harrison@manchester.ac.uk) 11 Dec 2008 - translated  from matlab clustering_e_step_full
29   * @version $Name:  $
30   * @since VOTech Stage 8
31   */
32  public class ClusterEStepFull {
33  
34      // this is the E-step of the variational EM algoorithm for Louisa's data. In
35      // the data set, some variables are continuous with measurement errors, some
36      // variables are continous without measurement errors, some variables are
37      // binary, in the parameters, these data are separated as data, data_noerr,
38      // bin_data, S is the error associated
39      //function [output ab q C] = clustering_e_step_full(data, datatype, K, latent,...
40      //    ab, mu, cv, lmu, lcv, p, cv_type)
41      
42      public static class Retval {
43          public final AGDenseMatrix output;
44          public final AGDenseMatrix ab;
45          public final AGDenseMatrix q;
46          public final AGDenseMatrix C;
47          
48          public Retval(AGDenseMatrix output, AGDenseMatrix ab, AGDenseMatrix q, AGDenseMatrix C) {
49              this.output = output;
50              this.ab = ab;
51              this.q = q; 
52              this.C = C;
53          }
54      }
55      
56      public static Retval clustering_e_step_full(Matrix data, Matrix datatype, int K, Matrix latent,
57              AGDenseMatrix ab, Matrix mu, Matrix cv, Matrix lmu, Matrix lcv, Vector p, CovarianceKind cv_type
58             
59      ){
60      //%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
61      // E-step of the VB method
62      //%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
63      int ndata = data.numRows();
64      int no_of_data_types = datatype.numRows();
65      int outpos = 0; //the position of the pointer in the output array.
66      int abpos= 0;
67  
68      int n0 = 0, n1 = 0, nm = 0, ne = 0, d = 0;
69      Matrix data_nr = null, data_er = null, data_bin = null, data_mul = null, data_int = null;
70      Matrix gmu_nr = null,gcv_nr_f[] = new AGDenseMatrix[K], gcv_nr_d = null,  gmu = null, gcv_f[] = new AGDenseMatrix[K], gcv_d = null, S = null;
71      Vector  gcv_nr_c= null, gcv_c = null;
72      Matrix bp = null, mp = null, ip = null;
73      AGDenseMatrix C = new AGDenseMatrix(ndata, K);
74     
75      int ndim_nr = 0, ndim_er = 0,ndim_bin=0,ndim_mul = 0,ndim_int = 0;
76      for (int i = 0; i < no_of_data_types; i++){
77          if (datatype.get(i,0) == 1){     // continuous data without errors
78              if((ndim_nr = (int) datatype.get(i,1)) > 0){
79              data_nr = data.sliceCol(d,ndim_nr );
80              gmu_nr = reshape(mu.asVector(nm,nm+K*ndim_nr-1), K, ndim_nr);
81              nm = nm + K*ndim_nr;
82              gcv_nr_d = new AGDenseMatrix(K,ndim_nr);
83             switch (cv_type){
84                  case free:
85                      for (int k = 0; k < K; k++){
86                          gcv_nr_f[k] = reshape(cv.asVector(n0,n0+ndim_nr*ndim_nr -1),
87                              ndim_nr, ndim_nr);
88                          n0 = n0 + ndim_nr *ndim_nr;
89                      }
90                      break;
91                  case diagonal:
92                      for( int k=0; k < K; k++){
93                          gcv_nr_d.setRow(k,  reshape(cv.asVector(n0,n0+ndim_nr -1), 1, ndim_nr));
94                          n0 = n0 + ndim_nr;
95                     }
96                      break;
97                  case common:
98                      gcv_nr_c = cv.asVector(n0);
99                      n0 = n0 + K;
100                     break;
101             }
102             }
103             d = d + ndim_nr;
104         }
105         else if (datatype.get(i,0) == 2){ // continous data with errors
106             if((ndim_er = (int) datatype.get(i,1))!=0){
107             data_er = data.sliceCol(d,ndim_er);
108             gmu = reshape(lmu.asVector(ne,ne+K*ndim_er-1),K, ndim_er);
109             ne = ne + K*ndim_er;
110             gcv_d = new AGDenseMatrix(K,ndim_er);
111             switch (cv_type) {
112                 case free:
113                     for (int k = 0; k < K ; k++){
114                         gcv_f[k] = reshape(lcv.asVector(n1,n1+ndim_er*ndim_er-1), 
115                             ndim_er,ndim_er);
116                         n1 = n1 + ndim_er * ndim_er;
117                     }
118                     break;
119                 case diagonal:
120                     for(int k=0; k < K; k++){
121                         gcv_d.setRow(k,reshape(lcv.asVector(n1,n1+ndim_er-1), 1, ndim_er));
122                         n1 = n1 + ndim_er;
123                        }
124                     break;
125                 case common:
126                     gcv_c = lcv.asVector(n1);
127                 break;
128             }
129             d = d + ndim_er;
130             }
131         }
132         else if (datatype.get(i,0) == 3){
133              ndim_bin = (int) datatype.get(i,1);
134             data_bin = data.sliceCol(d,ndim_bin);
135             bp = reshape(mu.asVector(nm,nm+K*ndim_bin-1), K, ndim_bin);
136             nm = nm + K*ndim_bin;
137             d = d  + ndim_bin;
138         }
139         else if (datatype.get(i,0) == 4){
140              ndim_mul = (int) datatype.get(i,1);
141             data_mul = data.sliceCol(d, ndim_mul );
142             mp = reshape(mu.asVector(nm,nm+K*ndim_mul -1), K, ndim_mul);
143             nm = nm + K*ndim_mul;
144             d =  d + ndim_mul;
145         }
146         else if (datatype.get(i,0) == 5){
147             ndim_int = (int) datatype.get(i,1);       
148             data_int = data.sliceCol(d,ndim_int);
149             ip = reshape(mu.asVector(nm,nm+K*ndim_int-1), K, ndim_int);
150             nm = nm + K*ndim_int;
151             d= d + ndim_int;
152         }
153         else if (datatype.get(i,0) == 6){
154             int ndim_error = (int) datatype.get(i,1);
155             if (ndim_error != ndim_er){
156                 throw new IllegalArgumentException( "The dimension of measurement errors and ");
157             }        
158             S = data.sliceCol(d,ndim_error);
159             // error inforamtion
160             S.add(ones(ndata, ndim_er).scale(1.0e-6));
161             d = d + ndim_error;
162         }
163     }   
164 
165     AGDenseMatrix a = null, b = null;
166     Vector v = null;
167     if (ndim_er != 0 || ndim_nr != 0){
168         a = reshape(ab.asVector(0,ndata*K-1), ndata, K);
169         b = reshape(ab.asVector(ndata*K), ndata, K);
170         v = reshape(latent, K, 1).asVector();
171     }
172 
173     
174     //--------------------------------------------------------------------------
175     // THE POSTERIOR OF THE LATENT VARIABLES INCLUDING U AND W FOR THE VARIABLES
176     // WITH MEASUREMENT ERRORS
177     //--------------------------------------------------------------------------
178     int ndim= data.numColumns();       // data is 
179     AGDenseMatrix output = new AGDenseMatrix(ndim_er*(ndim_er+1)*K*ndata,1);
180 
181     AGDenseMatrix qcv[][]= new AGDenseMatrix[ndata][K];
182     DenseVector qmu[][]= new DenseVector[ndata][K]; //actually always a vector
183     if (ndim_er != 0) {
184         // (1) the centre of the auxiliary distribution q
185         Matrix epu = divide(a , b);         // the expectation of u
186         // qmu -- the centre of the q-distribution w.r.t w
187         // qcv -- the covariance of the q-distribution w.r.t w
188         for (int n=0 ; n <ndata; n++){
189             for (int k =0; k <K; k++){  
190                 switch (cv_type){
191                     case free:
192                         Matrix gcv_fe = new AGDenseMatrix(gcv_f[k]);
193                         gcv_fe.scale(1.0/epu.get(n,k));
194                         qcv[n][k] = mult(mult(diag(S.sliceRow(n)),inv(add(diag(S.sliceRow(n)),gcv_fe))),gcv_fe);
195                         qmu[n][k] = add(mult(gmu.sliceRowM(k),mult(diag(S.sliceRow(n)),inv(add(diag(S.sliceRow(n)),gcv_fe))))
196                             
197                                 ,mult(data_er.sliceRowM(n),mult(gcv_fe,inv(add(diag(S.sliceRow(n)),gcv_fe))))).asVector();
198                         outpos = output.insert( qcv[n][k], outpos);
199                         outpos = output.insert(qmu[n][k], outpos);
200                         
201                         break;
202                     case diagonal:
203                         Matrix gcv_scale = (Matrix) diag(gcv_d.sliceRow(k)).scale(1.0/epu.get(n,k));
204                         qcv[n][k]=mult(mult(diag(S.sliceRow(n)),inv(add(gcv_scale,
205                             diag(S.sliceRow(n))))),gcv_scale);
206                         qmu[n][k]=add(multBt(gmu.sliceRowM(k),mult(diag(S.sliceRow(n)),inv(add(gcv_scale,diag(S.sliceRow(n))))))
207                         ,multBt(data_er.sliceRowM(n),mult(gcv_scale,inv(add(gcv_scale,diag(S.sliceRow(n))))))).asVector();
208                         outpos = output.insert( qcv[n][k], outpos);
209                         outpos = output.insert(qmu[n][k], outpos);
210                         break;
211                     case common:                    
212                         no.uib.cipr.matrix.Matrix tcv = eye(ndim_er).scale(gcv_c.get(k)/epu.get(n,k));
213                         qcv[n][k]=mult(diag(S.sliceRow(n)).scale((gcv_c.get(k))/epu.get(n,k)),inv(
214                             diag(S.sliceRow(n)).add(gcv_c.get(k)/epu.get(n,k))));
215     //                     qcv[n][k] = diag(S.sliceRow(n))*inv(tcv/
216     //                         epu.get(n,k)+diag(S.sliceRow(n)))*tcv/epu.get(n,k);
217                         qmu[n][k]=add(multBt(gmu.sliceRowM(k),mult(diag(S.sliceRow(n)),inv(add(tcv ,diag(S.sliceRow(n))))))
218                                 ,mult(multBt(data_er.sliceRowM(n),tcv),inv(add(tcv,diag(S.sliceRow(n)))))).asVector();
219                         outpos = output.insert( qcv[n][k], outpos);
220                         outpos = output.insert(qmu[n][k], outpos);
221                         break;
222                 }
223             }
224         }
225         Matrix epdw = new AGDenseMatrix(ndata, K);
226         // update q-distribution w.r.t u
227         for (int n = 0; n < ndata; n++){
228             for (int k = 0; k < K; k++){
229                 switch (cv_type){
230                     case free:
231                         epdw.set(n,k, trace(mult(inv(gcv_f[k]),qcv[n][k])) +      
232                                 multATBA(qmu[n][k],inv(gcv_f[k])).asScalar()
233                             );
234                         // the expectation of w^T Sigma^{-1}w
235                         
236                         C.set(n,k, epdw.get(n,k)-2*multBt(multAt(qmu[n][k],inv(gcv_f[k]))
237                             ,gmu.sliceRowM(k)).get(0, 0)+multABAT(gmu.sliceRowM(k),inv(gcv_f[k])).get(0,0));
238                         break;
239                     case diagonal:
240                         epdw.set(n,k,trace(mult(inv(diag(gcv_d.sliceRow(k))),qcv[n][k]))
241                             +multATBA(qmu[n][k],inv(diag(gcv_d.sliceRow(k)))).asScalar());
242                         // the expectation of w^T Sigma^{-1}w
243                         C.set(n,k,epdw.get(n,k)-2*multBt(multAt(qmu[n][k],inv(diag(gcv_d.sliceRow(k))))
244                             ,gmu.sliceRowM(k)).get(0, 0)+multABAT(gmu.sliceRowM(k),inv(diag(gcv_d.sliceRow(k)))).get(0, 0));
245                         break;
246                     case common:
247                         Vector tqmu = new DenseVector(qmu[n][k]);
248                         tqmu.add(-1, gmu.sliceRow(k));
249                         
250                         C.set(n,k, tqmu.dot(tqmu)/gcv_c.get(k)+ 
251                             trace(qcv[n][k])/gcv_c.get(k));
252     //                     epdw(n,k)=trace(qcv[n][k])/gcv(k)
253     //                         +qmu[n][k]'*qmu[n][k]/gcv(k);
254     //                     // the expectation of w^T Sigma^{-1}w
255     //                     C(n,k)=epdw(n,k)-2*qmu[n][k]'*gmu(k,:)'/gcv(k)
256     //                         +gmu(k,:)*gmu(k,:)'/gcv(k);
257                         break;
258                 }
259             }
260         }
261     }
262     Matrix n3 = null;
263     if (ndim_nr != 0){
264         //----------------------------------------------------------------------
265         // THE POSTERIOR OF THE LATENT VARIABLE U FOR THE DATA WITH NO ERROR
266         //----------------------------------------------------------------------
267 
268         //gcv_nr has different forms for the different models
269         switch (cv_type){
270         case free: 
271 
272             n3 = dist3_free(data_nr, gmu_nr, gcv_nr_f);
273 
274             break;
275             
276         case common:
277             n3 = dist3_common(data_nr, gmu_nr, gcv_nr_c);
278            break;
279            
280         case diagonal:
281             n3 = dist3_diag(data_nr, gmu_nr, gcv_nr_d);
282            break;
283            
284 
285         }
286     }
287 
288 
289    
290     if (datatype.get(0,1) !=0 && datatype.get(1,1) == 0 ){// continuous data without errors
291     //     disp 'estimate parameters of u without measurement errors'
292         a = (AGDenseMatrix) (repmatt(v, ndata, 1).add(ndim_nr).scale(0.5));
293         b = (AGDenseMatrix) (repmatt(v, ndata, 1).add(n3).scale(0.5));
294         C = (AGDenseMatrix) n3;
295         abpos = ab.insert(a,abpos);
296         abpos = ab.insert(b,abpos);
297 
298     //     pause
299     }
300     else if( datatype.get(1,1) !=0 && datatype.get(0,1) == 0){
301     //     disp 'estimate parameter of u for variable with measurement errors'
302         a = (AGDenseMatrix) (repmatt(v, ndata, 1).add(ndim_er).scale(0.5));
303         b = (AGDenseMatrix) (repmatt(v, ndata, 1).add(C)).scale(0.5);
304         abpos = ab.insert(a,abpos);
305         abpos = ab.insert(b,abpos);
306     //     C = C;
307     }
308     else if (datatype.get(0,1)!=0 && datatype.get(1,1) !=0){
309     //     disp 'estimate parameter of u for combined variables'
310         a = (AGDenseMatrix) (repmatt(v, ndata, 1).add(ndim_nr+ndim_er).scale(0.5));
311         b = (AGDenseMatrix) (repmatt(v, ndata, 1).add(C).add(n3).scale(0.5));
312         abpos = ab.insert(a,abpos);
313         abpos = ab.insert(b,abpos);
314         C.add(n3);
315     }
316 
317     Matrix aux1 = new AGDenseMatrix(ndata,K), aux2 = new AGDenseMatrix(ndata,K), aux3 = new AGDenseMatrix(ndata,K), aux4 = new AGDenseMatrix(ndata,K), 
318     aux5 = new AGDenseMatrix(ndata,K), aux6 = new AGDenseMatrix(ndata,K), 
319      aux7 = new AGDenseMatrix(ndata,K), aux8 = new AGDenseMatrix(ndata,K), aux9 = new AGDenseMatrix(ndata,K), aux10 = new AGDenseMatrix(ndata,K), 
320      aux11 = new AGDenseMatrix(ndata,K);
321     aux1 = zeros(ndata, K);
322     aux2 = zeros(ndata, K);
323     aux3 = zeros(ndata, K);
324     aux4 = zeros(ndata, K);
325     aux5 = zeros(ndata, K);
326     aux6 = zeros(ndata, K);
327     aux9 = zeros(ndata, K);
328     aux10 = zeros(ndata, K);
329     aux11 = zeros(ndata, K);
330 
331     //--------------------------------------------------------------------------
332     //   CALCULATE THE RESPONSIBILITIES
333     //--------------------------------------------------------------------------
334     if (ndim_er != 0){
335     //     disp 'calculate components of continuous variable with errors for responsibilities'
336         for (int n = 0; n < ndata; n++){
337             for (int k = 0; k <K; k++){
338                 switch (cv_type){
339                     case free:                  
340                         aux1.set(n,k, -ndim_er/2.0*log(2*PI)-log(det(diag(S.sliceRow(n))))/2.0-
341                             (multBt(mult(data_er.sliceRowM(n),inv(diag(S.sliceRow(n)))),data_er.sliceRowM(n)).asScalar() -  
342                             multBt(multAt(qmu[n][k],inv(diag(S.sliceRow(n)))),data_er.sliceRowM(n)).asScalar()*2 + 
343                             trace(mult(inv(diag(S.sliceRow(n))),qcv[n][k]))
344                             + multATBA(qmu[n][k],inv(diag(S.sliceRow(n)))).asScalar()
345                             )/2.0);
346                         aux2.set(n,k, -ndim_er/2.0*log(2*Math.PI)-log(det(gcv_f[k]
347                             ))/2.0+ndim_er/2.0*(psi(a.get(n,k))-log(b.get(n,k)))-
348                             0.5*a.get(n,k)/b.get(n,k)*(                            
349                             multATBA(qmu[n][k],inv(gcv_f[k])).asScalar()+ (trace(mult(inv(gcv_f[k]),qcv[n][k])))
350                             -2*multBt(multAt(qmu[n][k],inv(gcv_f[k])),gmu.sliceRowM(k)).asScalar()
351                             + multABAT(gmu.sliceRowM(k),inv(gcv_f[k])).asScalar() ));
352                         // -E_q[ LOG q(W|K) ]
353                         aux4.set(n,k, ndim_er/2.0*log(2*Math.PI)+log(det(qcv[n][k]))/2.0+ndim_er/2.0);     
354                             break;
355                     case diagonal:
356                         aux1.set(n,k, -ndim_er/2.0*log(2*Math.PI)-log(det(diag(S.sliceRow(n))))/2.0-
357                             (multABAT(data_er.sliceRowM(n),inv(diag(S.sliceRow(n)))).asScalar() - 
358                             2*multBt(multAt(qmu[n][k],inv(diag(S.sliceRow(n)))),data_er.sliceRowM(n)).asScalar() +
359                             trace(mult(inv(diag(S.sliceRow(n))),qcv[n][k])) 
360                             + multATBA(qmu[n][k],inv(diag(S.sliceRow(n)))).asScalar() )/2.0);                    
361                         Vector covk = gcv_d.sliceRow(k);                    
362                         aux2.set(n,k, -ndim_er/2.0*log(2*Math.PI)-sum(log(covk))/2.0 +
363                             ndim_er/2.0*(psi(a.get(n,k))-log(b.get(n,k)))-
364                             0.5*a.get(n,k)/b.get(n,k)*C.get(n,k));
365                         // -E_q[ LOG q(W|K) ]
366                         aux4.set(n,k, ndim_er/2.0*log(2*Math.PI)+log(det(qcv[n][k]))/2.0+ndim_er/2.0);
367                         break;
368                     case common:
369                         aux1.set(n,k, -ndim_er/2.0*log(2*Math.PI)-log(det(diag(S.sliceRow(n))))/2.0-
370                             (multABAT(data_er.sliceRowM(n),inv(diag(S.sliceRow(n)))).asScalar() - 
371                             2*multBt(multAt(qmu[n][k],inv(diag(S.sliceRow(n)))),data_er.sliceRowM(n)).asScalar() +
372                             trace(mult(inv(diag(S.sliceRow(n))),qcv[n][k])) 
373                             + multATBA(qmu[n][k],inv(diag(S.sliceRow(n)))).asScalar()
374                              )/2.0);
375                         double covkd = gcv_c.get(k);
376                         aux2.set(n,k, -ndim_er/2.0*log(2*Math.PI)-ndim_er*log(covkd)/2.0 +
377                             ndim_er/2.0*(psi(a.get(n,k))-log(b.get(n,k)))-
378                             0.5*a.get(n,k)/b.get(n,k)*C.get(n,k));
379                         // -E_q[ LOG q(W|K) ]
380                         aux4.set(n,k, ndim_er/2.0*log(2*Math.PI)+log(det(qcv[n][k]))/2.0+ndim_er/2.0);
381                         break;
382                     default:
383                         throw new IllegalArgumentException( "Unknown noise/covariance model");
384                 }
385                 // E_q [LOG(P(U|K))]
386                 aux3.set(n,k, v.get(k)/2.0*log(v.get(k)/2.0)+(v.get(k)/2.0-1)*(psi(a.get(n,k))-log(b.get(n,k)))
387                     -v.get(k)/2.0*a.get(n,k)/b.get(n,k)-log(gamma(v.get(k)/2.0)));
388                 // -E_q[ LOG q(U|K) ]
389                 aux5.set(n,k, -((a.get(n,k)-1)*psi(a.get(n,k))+log(b.get(n,k))-a.get(n,k)-
390                     log(gamma(a.get(n,k)))));
391             }
392         }
393     }
394     if (ndim_nr != 0){
395     //     disp 'calculate components of continus variable without errors for responsibilities'
396         // THE SECOND PART IS RELATED TO THE DATA WITHOUT MEASUREMENT ERRORS
397         for(int n = 0; n <ndata; n++){
398             for (int k = 0; k <K; k++){
399                
400                 if (ndim_er != 0){
401                     // E_q[ log(p(t|u, k)) ]
402                     switch (cv_type){
403                         case diagonal:
404                             Vector covkd = gcv_nr_d.sliceRow(k);
405                             aux6.set(n,k, -ndim_nr/2.0*log(2*Math.PI)-sum(log(covkd))/2.0+
406                                 ndim_nr/2.0*(psi(a.get(n,k))-log(b.get(n,k)))-
407                                 0.5*a.get(n,k)/b.get(n,k)*n3.get(n,k));
408                             break;
409                         case common:
410                             double covkc = gcv_nr_c.get(k);
411                             aux6.set(n,k, -ndim_nr/2.0*log(2*Math.PI)-ndim_nr*log(covkc)/2.0+
412                                 ndim_nr/2.0*(psi(a.get(n,k))-log(b.get(n,k)))-
413                                 0.5*a.get(n,k)/b.get(n,k)*n3.get(n,k));
414                             break;
415                         case free:
416                             aux6.set(n,k, -ndim_nr/2.0*log(2*Math.PI)-log(det(
417                                 gcv_nr_f[k]))/2.0+ndim_nr/2.0*(psi(a.get(n,k))
418                                 -log(b.get(n,k)))-0.5*a.get(n,k)/b.get(n,k)*n3.get(n,k));
419                             break;
420                         default:
421                             throw new IllegalArgumentException( "Unknown noise model");
422                     }
423                 }
424                 if (ndim_er == 0){ 
425     //                 disp 'e-step, ndim_er == 0'
426                     switch (cv_type){
427                         case common:
428     //                         disp 'e-step, ndim_er ~=0, spherical'
429                             double covk = gcv_nr_c.get(k);
430                             aux6.set(n,k, log(gamma((v.get(k)+ndim_nr)/2.0))- 
431                                 ndim_nr/2.0*log(covk)-ndim_nr/2.0*log(Math.PI*v.get(k)) 
432                                 -log(gamma(v.get(k)/2.0))-(v.get(k) + ndim_nr)/2.0* 
433                                 log(1+n3.get(n,k)/(covk*v.get(k))));
434                             break;
435                         case diagonal:
436     //                         disp 'e-step, ndim er ~=0, diagonal'
437                             Vector covkd = gcv_nr_d.sliceRow(k);
438                             
439                             double dist = sum(divide(pow(sub(data_nr.sliceRowM(n),gmu_nr.sliceRowM(k)),2.0).asVector() , covkd));
440                             aux6.set(n,k,  log(gamma((v.get(k)+ndim_nr)/2.0))-
441                                 sum(log(covkd))/2.0-
442                                 ndim_nr/2.0*log(Math.PI * v.get(k))-log(gamma(v.get(k)/2.0))-
443                                 (v.get(k) + ndim_nr)/2.0*log(1 + dist/v.get(k)));
444                             break;
445                         case free:
446                             Matrix covkf = gcv_nr_f[k];
447                             aux6.set(n,k, log(gamma((v.get(k)+ndim_nr)/2.0))-log(det(
448                                 covkf))/2.0-ndim_nr/2.0*log(Math.PI*v.get(k))-log(gamma(v.get(k)/2.0))-
449                                 ((v.get(k)+ndim_nr)/2.0)*log(1+n3.get(n,k)/v.get(k)));
450                             break;
451                     }
452                 }
453             }
454         }
455     }
456 
457     
458     Vector sgm;
459     Vector datai;
460     if (ndim_bin !=0){
461     //     disp 'calculate components of binary variable for responsibilities'
462         //  THE THIRD PART IS RELATED TO THE BINARY DATA
463         //  we need first calculate its mean
464         for (int i = 0 ; i < ndata; i++){
465             for (int j = 0 ; j < K; j++){
466                 sgm = add(bp.sliceRow(j), eps);
467                 datai = data_bin.sliceRow(i);
468                 aux9.set(i,j, sum(add(times(datai,log(sgm)),times(sub(1.0,datai),log(sub(1.0,sgm))))));
469             }
470         }
471     }
472 
473     if (ndim_mul != 0){
474     //     disp 'calculate components of category variable for responsibilities'
475     //     pause    
476         for(int n = 0 ; n < ndata; n++){
477             for(int k = 0 ; k< K; k++){
478                 Vector temp = add(mp.sliceRow(k),eps);
479                 Vector log_sgm = log(temp);
480                 Vector datat = data_mul.sliceRow(n);
481                 aux10.set(n,k, sum(times(datat,log_sgm) ));
482             }
483         }
484     }
485 
486     if (ndim_int !=0){
487     //     disp 'calculate components of int variable for responsibilities'
488         for(int t=0 ; t <ndata; t++){
489             for (int i = 0 ; i < K; i++){
490                 sgm = ip.sliceRow(i);
491                 datai = data_int.sliceRow(t);
492                 double poiss = sum(sub(times(datai,log(sgm)),sgm));
493                 aux11.set(t,i,  poiss-sum(log(seq(ndim_int))));//FIXME - put ndim_int instead of datai here possible mistake in original?
494             }
495         }
496     }
497 
498     Matrix aux = add(add(add(add(add(add(add(add(aux1 , aux2) ,aux3),  aux4), aux5), aux6), aux9), aux10) , aux11);
499 
500     AGDenseMatrix q = (AGDenseMatrix) times(multBt(ones(ndata, 1),new AGDenseMatrix(p)),exp(aux));
501     Matrix s = new AGDenseMatrix(sum(q, 2));
502     // Set any zeros to one before dividing 
503     for (MatrixEntry matrixEntry : s) {
504         if(matrixEntry.get() == 0.0)
505         {
506             matrixEntry.set(1.0);
507         }
508     }
509     q = (AGDenseMatrix) divide(q,mult(s,ones(1, K)));
510 
511     Retval retval = new Retval(output, ab, q, C);
512     
513     return retval;
514 
515     }
516     
517     
518     
519  
520     
521  
522 }
523 
524 /*
525  * $Log: ClusterEStepFull.java,v $
526  * Revision 1.6  2010/01/11 21:22:46  pah
527  * reasonable numerical stability and fidelity to MATLAB results achieved
528  *
529  * Revision 1.5  2010/01/05 21:27:13  pah
530  * basic clustering translation complete
531  *
532  * Revision 1.4  2009/09/20 17:18:01  pah
533  * checking just prior to bham visit
534  *
535  * Revision 1.3  2009/09/14 19:08:43  pah
536  * code runs clustering, but not giving same results as matlab exactly
537  *
538  * Revision 1.2  2009/09/08 19:23:30  pah
539  * got rid of npe and array bound problems....
540  *
541  * Revision 1.1  2009/09/07 16:06:11  pah
542  * initial transcription of the core
543  *
544  */