View Javadoc

1   /*
2    * $Id: ClusterMStep.java,v 1.3 2010/01/11 21:22:46 pah Exp $
3    * 
4    * Created on 21 Sep 2009 by Paul Harrison (paul.harrison@manchester.ac.uk)
5    * Copyright 2009 AstroGrid. All rights reserved.
6    *
7    * This software is published under the terms of the AstroGrid 
8    * Software License, a copy of which has been included 
9    * with this distribution in the LICENSE.txt file.  
10   *
11   */ 
12  
13  package org.astrogrid.cluster.cluster;
14  import no.uib.cipr.matrix.AGDenseMatrix;
15  import no.uib.cipr.matrix.DenseVector;
16  import no.uib.cipr.matrix.Vector;
17  
18  import org.astrogrid.matrix.Matrix;
19  import static org.astrogrid.matrix.MatrixUtils.*;
20  import static org.astrogrid.matrix.Algorithms.*;
21  import static java.lang.Math.*;
22  
23      public class ClusterMStep {
            /*
             * M-step of the clustering model: re-estimates the global model parameters
             * from the data and the matrix q. The code reads like a line-by-line
             * translation of a MATLAB function returning [mu cv lmu lcv p].
             *
             * Most matrix helpers used below (add, inv, diag, multBt, vprod, pow,
             * reshape, sum, times, mean, zeros, ones, eps) are not declared in this
             * file; presumably they come from the static imports of MatrixUtils,
             * Algorithms and java.lang.Math - TODO confirm their exact semantics
             * (in particular which of them modify their argument in place).
             */

            /**
             * Immutable holder for the five results of
             * {@code clustering_m_step}. The field comments describe what that
             * method puts into each value.
             */
24          public static class Retval {
                /** Means/probabilities appended for the error-free continuous block and for the type 3, 4 and 5 blocks. */
25              public final AGDenseMatrix mu;
                /** Covariance parameters appended for the continuous block without errors. */
26              public final AGDenseMatrix cv;
                /** Means appended for the continuous block with measurement errors. */
27              public final AGDenseMatrix lmu;
                /** Covariance parameters appended for the continuous block with measurement errors. */
28              public final AGDenseMatrix lcv;
                /** Result of {@code mean(q, 1)}, labelled "priors" by the method. */
29              public final Vector p; 
                /**
                 * Stores the given references as-is (no copies are made).
                 *
                 * @param mu  value for {@link #mu}
                 * @param cv  value for {@link #cv}
                 * @param lmu value for {@link #lmu}
                 * @param lcv value for {@link #lcv}
                 * @param p   value for {@link #p}
                 */
30          public Retval(AGDenseMatrix mu,AGDenseMatrix cv, AGDenseMatrix lmu, AGDenseMatrix lcv, Vector p ) {
31  
32              this.mu = mu;
33              this.cv = cv;
34              this.lmu = lmu;
35              this.lcv = lcv;
36              this.p = p;            
37          }
38  
39      }
40  
41      // -------------------------------------------------------------------------
42      // Re-estimate the global parameters of the model
43      //function  [mu cv lmu lcv p]
44      //--------------------------------------------------------------------------
        /**
         * Re-estimates the model parameters from the data and the matrix
         * {@code q} (the M-step).
         *
         * <p>The rows of {@code datatype} are processed in order. Column 0 of a
         * row holds a type code and column 1 the number of columns of
         * {@code alldata} that belong to that block; blocks are sliced from
         * {@code alldata} one after another, starting at column 0:
         * <ul>
         * <li>1 - continuous data without errors</li>
         * <li>2 - continuous data with errors</li>
         * <li>3 - binary data</li>
         * <li>4 - kept as {@code data_mul} (presumably multinomial data - confirm)</li>
         * <li>5 - kept as {@code data_int} (presumably integer data - confirm)</li>
         * <li>6 - measurement errors {@code S} belonging to the type 2 block</li>
         * </ul>
         *
         * <p>NOTE(review): {@code S} is only assigned by a type 6 row. If a type 2
         * block with non-zero dimension is present without a type 6 row, the
         * {@code S.sliceRow(n)} calls below dereference {@code null}. The type 6
         * dimension check compares against {@code ndim_er} as set so far, so the
         * type 6 row has to come after the type 2 row.
         *
         * @param alldata  data matrix; one row per data point, columns laid out
         *                 block by block as described by {@code datatype}
         * @param datatype one row per block: column 0 = type code, column 1 =
         *                 number of columns of that block
         * @param K        number of components; every per-component loop runs
         *                 {@code k = 0 .. K-1}
         * @param q        matrix read as {@code q.get(n, k)} with {@code n} a data
         *                 row and {@code k} a component (presumably the
         *                 responsibilities from the E-step - confirm)
         * @param lcv      packed covariance parameters of the block with errors;
         *                 unpacked according to {@code cv_type}. The parameter
         *                 variable is afterwards re-pointed at a new matrix, which
         *                 receives the re-estimated values.
         * @param cv_type  covariance structure: {@code free}, {@code diagonal} or
         *                 {@code common}
         * @return the values {@code mu, cv, lmu, lcv, p} wrapped in a {@link Retval}
         * @throws IllegalArgumentException if a type 6 row declares a dimension
         *         different from the current {@code ndim_er}
         */
45      static Retval  clustering_m_step(Matrix alldata, Matrix datatype, int K, AGDenseMatrix q,
46              AGDenseMatrix lcv, CovarianceKind cv_type){
47  
48          ////%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
49          // M-step of the VB method
50          //%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
51          int ndata = alldata.numRows();
52          int no_of_data_types = datatype.numRows();
            // One slice of alldata per block kind; stays null when that kind is absent.
53          Matrix data_nr = null, data_bin = null, data_mul = null, data_int = null, data_er = null;
            // Only meaningful for cv_type == common, where it is replaced by values taken from lcv.
54          Vector gcv_c = new DenseVector(no_of_data_types);
55          Matrix   S = null;//note that this is only initialized to something sensible in the case of data errors.
            // Column count of each block kind; 0 means "no such block" and disables the matching section below.
56          int ndim_nr = 0, ndim_er = 0,ndim_bin = 0,ndim_mul = 0,ndim_int = 0;
57  
            // n1: read position inside the incoming lcv; d: next unread column of alldata.
58          int n1 = 0, d = 0;
            // NOTE(review): qcv and qmu are allocated here but not referenced again in this method.
59          Matrix qcv[][] = new AGDenseMatrix[ndata][K];//FIXME these are probably not correct
60          DenseVector qmu[][] = new DenseVector[ndata][K];
61          Matrix gmu_nr = null,gcv_nr[] = new AGDenseMatrix[K], gcv_nr_d = null, gmu = null, gcv_f[] = new AGDenseMatrix[K], gcv_d = null;
            // Pass 1: walk the block descriptions, slice alldata accordingly and unpack the incoming lcv.
            // sliceCol(d, n) is presumably (first column, number of columns) - confirm.
62          for ( int i = 0; i < no_of_data_types; i++){
63              if(datatype.get(i,0) == 1     ){ // continuous data without errors
64                  if((ndim_nr = (int)datatype.get(i,1)) != 0){
65                  data_nr = alldata.sliceCol(d,ndim_nr);
66                  gcv_nr_d = new AGDenseMatrix(K,ndim_nr);
67                  d = d + ndim_nr;
68                  }
69              }
70              else if(datatype.get(i,0) == 2 ){ // continuous data with errors
71                  if((ndim_er = (int)datatype.get(i,1)) != 0){
72                  data_er = alldata.sliceCol(d,ndim_er);
73                  gcv_d = new AGDenseMatrix(K, ndim_er);
74                  switch(cv_type) {
75  
76                  case free:
                        // One ndim_er x ndim_er matrix per component, read consecutively from lcv.
                        // asVector(a, b) appears to take an inclusive end index - confirm.
77                      for ( int k = 0; k < K; k++){
78                          gcv_f[k] = reshape(lcv.asVector(n1,n1+ndim_er*ndim_er-1), 
79                                  ndim_er,ndim_er);
80                          n1 = n1 + ndim_er*ndim_er;
81                      }
82                      break;
83                  case diagonal:
                        // One row of ndim_er values per component.
84                      for ( int k = 0; k < K; k++){
85                          gcv_d.setRow(k,  reshape(lcv.asVector(n1,n1+ndim_er-1), 1, ndim_er));
86                          n1 = n1 + ndim_er;
87                      }
88                      break;
89                  case common:
                        // Presumably the remainder of lcv starting at n1; later read as one value per component.
90                      gcv_c = lcv.asVector(n1); 
91                  }
92                  d = d + ndim_er;
93                  }
94              }
95              else if (datatype.get(i,0) == 3) {
96                  if((ndim_bin = (int)datatype.get(i,1)) != 0){
97                  data_bin = alldata.sliceCol(d,ndim_bin);
98                  d = d  + ndim_bin;
99                  }
100             }
101             else if (datatype.get(i,0) == 4) {
102                 if((ndim_mul =(int) datatype.get(i,1)) != 0){
103                 data_mul = alldata.sliceCol(d,ndim_mul);
104                 d = d + ndim_mul;
105                 }
106             }
107             else if (datatype.get(i,0) == 5) {
108                 if((ndim_int = (int)datatype.get(i,1)) != 0){       
109                 data_int = alldata.sliceCol(d,ndim_int);
110                 d = d + ndim_int;
111                 }
112             }
113             else if (datatype.get(i,0) == 6) {
114                 int ndim_error =(int) datatype.get(i,1);
                    // The error block must have as many columns as the type 2 block seen so far.
                    // NOTE(review): the exception message below looks truncated.
115                 if(ndim_error != ndim_er){ //
116                     throw new IllegalArgumentException( "The dimension of measurement errors and ");
117                 }        
118                 S = alldata.sliceCol(d,ndim_error);
119                 // error information
                    // Presumably adds 1e-6 to every entry of S in place so that no error term is exactly zero - confirm.
120                 S.add(1.0e-6, ones(ndata, ndim_er));
121                 d = d + ndim_error;
122             }
123         }
124         // maximize for the parameters gmu, gcv for the data with measurement errors
125         // [ndata ndim] = size(data);
126         ////%%%%%%%%%%%%%%%%%%
127         //   gmu
128         //%%%%%%%%%%%%%%%%%%%
            // Output accumulators start empty and are grown with append(...) below.
129         AGDenseMatrix mu = new AGDenseMatrix(0,0) ,cv = new AGDenseMatrix(0,0), lmu = new AGDenseMatrix(0,0);
            // The incoming lcv has been unpacked into gcv_f / gcv_d / gcv_c above; from here on lcv is the output.
130         lcv = new AGDenseMatrix(0,0);
131         gmu = new AGDenseMatrix(K, ndim_er);
132         if(ndim_er != 0  ){ //
                // For each component k, with C_k the current covariance of k and S_n the errors of point n:
                //   tmpg = sum_n q(n,k) * inv(C_k + diag(S_n)) * x_n
                //   tmpt = sum_n q(n,k) * inv(C_k + diag(S_n))
                //   gmu(k,:) = inv(tmpt) * tmpg
                // (assuming add/inv/diag/mult/multBt have their usual matrix meanings - confirm)
133             for ( int k = 0; k < K; k++){
134                 AGDenseMatrix tmpg = zeros(ndim_er,1);
135                 AGDenseMatrix tmpt = zeros(ndim_er, ndim_er);
136                 for ( int n = 0; n < ndata; n++){
137                     switch(cv_type) {
138                     case free:
139                         tmpg.add(q.get(n,k),multBt(inv(add((gcv_f[k]),diag(  
140                                 S.sliceRow(n)))),data_er.sliceRowM(n)));
141                         tmpt.add(q.get(n,k),inv(add((gcv_f[k]),diag(S.sliceRow(n)))));
142                         break;
143                     case diagonal:
144                         tmpg.add( q.get(n,k),multBt(inv(add(diag(gcv_d.sliceRow(k)),diag(S.sliceRow(n))))
145                                 ,data_er.sliceRowM(n)));
146                         tmpt.add(q.get(n,k),inv(add(diag(gcv_d.sliceRow(k)),diag(S.sliceRow(n)))));
147                         break;
148                     case common:
                            // Scalar-plus-matrix form of add; whether the scalar is added to every
                            // entry or only to the diagonal is not visible here - TODO confirm.
149                         tmpg.add(q.get(n,k),multBt(inv(add(gcv_c.get(k),diag(S.sliceRow(n)))),data_er.sliceRowM(n)));
150                         tmpt.add(q.get(n,k),inv(add(gcv_c.get(k),diag(S.sliceRow(n)))));
151                         break;
152                     }
153                 }
154                 gmu.setRow(k , mult(inv(tmpt),tmpg).asVector());        
155             }
156             lmu.append( gmu.asVector());
157             //%%%%%%%%%%%%%%%%%%
158             //gcv
159             //%%%%%%%%%%%%%%%%%%%
                // The covariance parameters have no closed form here: each component is handed to
                // Minimize.minimize together with the name of an objective function. The 10 is
                // presumably an iteration/evaluation limit - confirm. For free and diagonal the
                // start value is pow(., 0.5) of the current parameters and the result is squared again.
160             for ( int k = 0; k < K; k++){
161                 switch(cv_type) {
162 
163                 case free:
164                     Matrix cvk = pow((gcv_f[k]),0.5);
165                     Vector tmpcv = Minimize.minimize("objectiveFull", 10, cvk.asVector(), q.sliceCol(k), 
166                             gmu.sliceRow(k), data_er, S);
167                     Matrix tmpcvm = pow(reshape(tmpcv,ndim_er,ndim_er),2.0);
168                     lcv.append(tmpcvm.asVector());
169                     break;
170                 case diagonal:
171 
                        // tmpcv was declared in the "free" case; the declaration is in scope for the whole switch block.
172                     Vector cvkv = pow(gcv_d.sliceRow(k),0.5);               
173                     tmpcv =Minimize.minimize("objectiveDiag",10,cvkv,q.sliceCol(k),gmu.sliceRow(k),
174                             data_er, S);                
175                     lcv.append(pow(tmpcv,2));
176                     break;
177 
178                 case common:
                        // NOTE(review): unlike the two cases above, the start value is gcv_c.get(k) itself
                        // (no square root), yet the result is squared - confirm against the MATLAB original.
179                     Vector cvk_c = new DenseVector(new double[]{gcv_c.get(k)});
180                     tmpcv = Minimize.minimize("objectiveSpherical",10,cvk_c,q.sliceCol(k), gmu.sliceRow(k),
181                             data_er, S);
182                     gcv_c.set(k , pow(tmpcv.get(0),2.0));
183                     lcv.append(gcv_c.get(k));
184                 }
185             }
186         }
187 
188         if(ndim_nr != 0){ 
189             // maximize for the parameters gmu, gcv for the data without errors
190 
191             //%%%%%%%%%%%%%%%%%%%
192             //   gmu
193             //%%%%%%%%%%%%%%%%%%%%
                // gmu_nr(k,:) = sum_n q(n,k) * x_n / sum_n q(n,k)
                // NOTE(review): this denominator has no eps guard, unlike the covariance updates below.
194             gmu_nr = new AGDenseMatrix(K,ndim_nr);
195             for ( int k = 0; k < K; k++){
196                 Vector tmp = new DenseVector(ndim_nr).zero();
197                 for ( int n = 0; n < ndata; n++){
198                     tmp.add(q.get(n,k),data_nr.sliceRow(n));
199                 }
200                 gmu_nr.setRow(k, tmp.scale(1.0/sum(q.sliceCol(k))));
201             }
202             mu.append( gmu_nr.asVector());
203             //%%%%%%%%%%%%%%%%%%
204             // gcv
205             //%%%%%%%%%%%%%%%%%%
206             for ( int k = 0; k < K; k++){
207                 switch(cv_type) {
208 
209 
210                 case free:
                        // Weighted sum of vprod(mean - x_n) (presumably the outer product of the
                        // difference with itself - confirm), divided by sum(q(:,k) + eps).
211                     Matrix gcvt = zeros(ndim_nr, ndim_nr);
212                     for ( int n = 0; n < ndata; n++){
213                         Vector tmp = sub(gmu_nr.sliceRow(k),data_nr.sliceRow(n));
214 
215                         gcvt.add(q.get(n,k),vprod(tmp));
216                     }
217                     gcv_nr[k] = (Matrix) gcvt.scale(1.0 / sum(add(q.sliceCol(k),eps)));
218 
219                     cv.append(gcv_nr[k].asVector());
220                     break;
221                 case diagonal:
                        // Per-dimension weighted variance around the component mean.
222                     for ( int i = 0; i < ndim_nr; i++){
223                         double tmp = 0.0;
224                         for ( int n = 0; n < ndata; n++){
225                             tmp = tmp + q.get(n,k)*pow(data_nr.get(n,i)-gmu_nr.get(k,i),2.0);
226                         }
227                         gcv_nr_d.set(k,i, tmp/ sum(add(q.sliceCol(k),eps)));
228                     }
229                     cv.append(gcv_nr_d.sliceRow(k));
230                     break;
231                 case common:
                        // A single variance per component: weighted squared distance to the mean,
                        // averaged over the ndim_nr dimensions and floored at eps.
232                     double tmpt = 0.0;
233                     for ( int n = 0; n < ndata; n++){
234 
235                         Vector tmp = sub(data_nr.sliceRow(n), gmu_nr.sliceRow(k));
236                         tmpt=tmpt+q.get(n,k)*tmp.dot(tmp);
237                     }
238                     tmpt = tmpt/(ndim_nr*sum(add(q.sliceCol(k),eps)));
239                     if(tmpt < eps){ 
240                         tmpt = eps;
241                     }
242                     cv.append(tmpt);
243 
244                 }
245             }
246         }
247 
248         //--------------------------------------------------------------------------
249         // the parameters for the binary data
250         //--------------------------------------------------------------------------
251 
252         if(ndim_bin != 0){ //
253             //     [ndata bin_dim] = size(bin_data);
254             int  bin_dim = data_bin.numColumns();
255             Matrix bp = new AGDenseMatrix(K,bin_dim);
                // NOTE(review): ndata is reassigned here (and in the next section) but not read afterwards.
256             ndata = data_bin.numRows();
                // bp(k,j) = sum_n q(n,k) * data_bin(n,j) / sum_n (q(n,k) + eps)
                // (times is presumably an element-wise product - confirm)
257             for ( int k = 0; k < K; k++){
258                 for ( int j = 0; j < ndim_bin; j++){
259                     //             trialno = 1;
260                     //             bp(k,j) = sum((q.sliceCol(k).*data_bin(:,j)))/(trialno*sum(q.sliceCol(k)+eps));
261                     bp.set(k,j, sum(times(q.sliceCol(k),data_bin.sliceCol(j)))/(sum(add(q.sliceCol(k),eps))));
262                 }
263             }    
264             mu.append(bp.asVector());
265         }
266 
267 
268         if(ndim_mul != 0){ //
269             int dim_mul = data_mul.numColumns();
270             ndata = data_mul.numRows();
271             Matrix mp = new AGDenseMatrix(K,dim_mul);
                // mp(k,i) = sum_n q(n,k) * data_mul(n,i); each row is then divided by its own sum.
272             for ( int k = 0; k < K; k++){
273                 for ( int i = 0; i < dim_mul; i++){
274                     mp.set(k,i, sum(times(q.sliceCol(k),data_mul.sliceCol(i))));
275                 }
276                 mp.setRow(k, mp.sliceRow(k).scale(1.0/sum(mp.sliceRow(k))));
277             }
278             mu.append(mp.asVector());
279         }
280 
281         if(ndim_int != 0){ //
282             int dim_int = data_int.numColumns();
283             Matrix ip = new AGDenseMatrix(K,dim_int);
                // ip(k,i) = sum_n q(n,k) * data_int(n,i) / sum_n q(n,k)   (no eps guard in this denominator)
284             for ( int k = 0; k < K; k++){
285                 for ( int i = 0; i < dim_int; i++){
286                     ip.set(k,i , sum(times(q.sliceCol(k),data_int.sliceCol(i)))/sum(q.sliceCol(k)));
287                 }
288             }
289             mu.append(ip.asVector());
290         }
291 
292         //%%%%%%%%%%%%%%%%%
293         // priors
294         //%%%%%%%%%%%%%%%%%
            // Presumably the mean of q over the data points, giving one value per component - confirm the axis argument.
295         Vector p = mean(q, 1);
296 
297         
298         return new Retval(mu, cv, lmu, lcv, p);
299 
300     }
301 
302 }
303 
304 
305 /*
306  * $Log: ClusterMStep.java,v $
307  * Revision 1.3  2010/01/11 21:22:46  pah
308  * reasonable numerical stability and fidelity to MATLAB results achieved
309  *
310  * Revision 1.2  2010/01/05 21:27:13  pah
311  * basic clustering translation complete
312  *
313  * Revision 1.1  2009/09/22 07:04:16  pah
314  * daily checkin
315  *
316  */