1
2
3
4
5
6
7
8
9
10
11
12
13 package org.astrogrid.matrix;
14
15 import static org.astrogrid.matrix.MatrixUtils.*;
16
17 import java.util.SortedMap;
18 import java.util.TreeMap;
19
20 import org.apache.commons.math.special.Gamma;
21 import org.netlib.blas.BLAS;
22 import org.netlib.blas.Dtrsv;
23 import org.netlib.lapack.LAPACK;
24
25 import no.uib.cipr.matrix.AGDenseMatrix;
26 import no.uib.cipr.matrix.DenseCholesky;
27 import no.uib.cipr.matrix.DenseMatrix;
28 import no.uib.cipr.matrix.DenseVector;
29 import no.uib.cipr.matrix.MatrixEntry;
30 import no.uib.cipr.matrix.Vector;
31
32 import static java.lang.Math.*;
33
34
35
36
37
38
39
40 public class Algorithms {
41
42 public static final double eps= 2.2204e-16 ;
43
44 private static final org.apache.commons.logging.Log logger = org.apache.commons.logging.LogFactory
45 .getLog(Algorithms.class);
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83 public static void kmeans(Matrix centres, Matrix data, int niters, double absprec, double errprec, boolean fromData, boolean verbose) {
84
85 int ndata = data.numRows(), data_dim = data.numColumns();
86 int ncentres = centres.numRows(), dim= centres.numColumns();
87 double sumsq;
88
89 if (dim != data_dim)
90 {
91 logger.error("Data dimension does not match dimension of centres");
92 return;
93 }
94
95 if (ncentres > ndata){
96 logger.error("More centres than data");
97 return;
98 }
99
100
101 int store = 1;
102
103 Matrix errlog = (Matrix) new AGDenseMatrix(1, niters).zero();
104
105
106
107 if (fromData){
108
109 int[] perm = randperm(ndata-1);
110
111
112 for (int i = 0; i < ncentres; i++) {
113 for(int j = 0; j < data_dim; j++)
114 {
115 centres.set(i, j, data.get(perm[i], j));
116 }
117 }
118 }
119
120
121 Matrix old_centres = new AGDenseMatrix(ncentres, dim);
122 double e = 0, old_e = 0;
123
124 for(int n = 0; n < niters; n++){
125
126
127 old_centres.set(centres);
128
129
130 Matrix d2 = dist2(data, centres);
131
132 MvRet mv = minvals(d2);
133
134
135
136 double[][] sums = new double[data_dim][ncentres];
137 int[] num_points = new int[ncentres];
138 e = 0;
139 for(int i = 0; i < ndata; i++)
140 {
141 for (int j = 0; j < data_dim; j++){
142 sums[j][mv.index[i]] += data.get(i, j);
143 }
144 num_points[mv.index[i]]++;
145 e += mv.minvals[i];
146 }
147 for (int j = 0; j < ncentres; j++){
148 if (num_points[j] > 0){
149 for (int i=0; i < data_dim; i++){
150 centres.set(j,i, sums[i][j]/num_points[j]);
151 }
152 }
153 }
154
155
156
157
158
159 if (verbose){
160 System.out.printf( "Cycle %4d Error %11.6f\n", n, e);
161 }
162
163 if ( n > 1 ){
164
165 if (old_centres.add(-1.0, centres).norm(Matrix.Norm.Maxvalue) < absprec &&
166 Math.abs(old_e - e) < errprec){
167 sumsq = e;
168 return;
169 }
170 }
171 old_e = e;
172 }
173
174
175
176 sumsq = e;
177 if (verbose){
178 logger.warn("Warning: Maximum number of iterations has been exceeded");
179 }
180
181
182
183 }
184
185
186 public static Matrix centre_kmeans(Matrix data, int nclus, int ndim){
187
188
189 Matrix centres= (Matrix) new AGDenseMatrix(nclus, ndim).zero();
190 kmeans(centres, data, 15, 0.001, 0.001, true, true);
191 return centres;
192
193 }
194
195 private static class MvRet {
196 int[] index;
197 double[] minvals;
198 }
199 private static MvRet minvals(Matrix d) {
200
201 MvRet retval = new MvRet();
202 retval.index = new int[d.numRows()];
203 retval.minvals = new double[d.numRows()];
204 for (int i = 0; i < d.numRows(); i++) {
205 int icolmin = 0;
206 retval.minvals[i] = d.get(i, 0);
207 for (int j = 1; j < d.numColumns(); j++){
208 double val = d.get(i, j);
209 if(retval.minvals[i] > val){
210 retval.minvals[i] = val;
211 icolmin = j;
212 }
213
214 }
215 retval.index[i] = icolmin;
216
217 }
218 return retval;
219 }
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237 static public Matrix dist2(Matrix x, Matrix c) {
238 int ndata= x.numRows(), dimx = x.numColumns() ;
239 int ncentres = c.numRows(), dimc = c.numColumns();
240 assert dimx == dimc:
241 "Data dimension does not match dimension of centres";
242 Matrix y = AGDenseMatrix.repeatColumn(x.pow(2).sum(2),ncentres);
243 Matrix d = AGDenseMatrix.repeatRow(c.pow(2).sum(2), ndata);
244 Matrix result = new AGDenseMatrix(ndata,ncentres);
245
246 x.transBmult(-2.0, c, result);
247 result.add(y.add(d));
248
249
250 for (MatrixEntry e : result) {
251 if(e.get() < 0){
252 e.set(0.0);
253 }
254 }
255 return result;
256 }
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290 public static int[] randperm(int n){
291 SortedMap<Double, Integer> sortmap = new TreeMap<Double, Integer>();
292 for (int i = 0; i < n; i++) {
293 sortmap.put(Math.random(), i+1);
294 }
295
296 int[] retval = new int[n];
297 int i=0;
298 for (Integer ii : sortmap.values()) {
299 retval[i++] = ii;
300 }
301 return retval;
302
303 }
304
305
306 static public Matrix dist3_common(Matrix x, Matrix centres, Vector covars){
307
308 int ndata = x.numRows(), ndim = x.numColumns();
309 int K = centres.numRows(), T =centres.numColumns();
310 Matrix n2;
311
312 n2 = divide(dist2(x, centres) , repmatt(covars, ndata, 1));
313 return n2;
314
315 }
316
317 static public Matrix dist3_diag(Matrix x, Matrix centres, Matrix covars){
318
319 int ndata = x.numRows(), ndim = x.numColumns();
320 int K = centres.numRows(), T =centres.numColumns();
321 Matrix n2 = new AGDenseMatrix(ndata, K);
322 for (int i = 0; i < K ;i++){
323 Matrix diffs = sub(x , mult(ones(ndata, 1) , centres.sliceRowM(i)));
324 Vector sums = sum(divide(pow(diffs,2.0),(mult(ones(ndata, 1) ,
325 covars.sliceRowM(i)))), 2);
326 for (int j = 0; j < ndata ; j++)
327 n2.set(j, i, sums.get(j) );
328 }
329
330
331 return n2;
332 }
333 static public Matrix dist3_free(Matrix x, Matrix centres, Matrix[] covars){
334
335 int ndata = x.numRows(), ndim = x.numColumns();
336 int K = centres.numRows(), T =centres.numColumns();
337 Matrix n2 = new AGDenseMatrix(ndata, K);
338
339 for (int i = 0 ;i<K; i++){
340 AGDenseMatrix diffs = (AGDenseMatrix) sub(x , mult(ones(ndata, 1) , centres.sliceRowM(i)));
341
342
343 DenseCholesky c = DenseCholesky.factorize(covars[i]);
344
345 no.uib.cipr.matrix.Matrix temp = new DenseMatrix(ndim, ndata);
346 transpose(c.getU()).solve(transpose(diffs), temp);
347
348 n2.setColumn(i,sum(times(temp,temp),1));
349 }
350
351
352 return n2;
353 }
354
355
356 public static double gamma(double s){
357 return Math.exp(Gamma.logGamma(s));
358 }
359
360 public static Vector mean(Matrix m, int i){
361 Vector retval = sum(m, i);
362 if(i == 1){
363 return retval.scale(1.0/m.numRows());
364 }else {
365 return retval.scale(1.0/m.numColumns());
366 }
367 }
368
369 public static int rem(int a, int b){
370 return a % b;
371 }
372
373
374 public static Vector multinorm(Matrix x, Vector m, Matrix covar ){
375
376
377
378
379 int dim = x.numRows(), npoints=x.numColumns();
380 covar = add(covar, eye(dim,Double.MIN_VALUE));
381 double dd = det(covar);
382 Matrix in = inv(covar);
383
384
385 double ff = pow((2*PI),(-dim/2.0))*pow((dd),(-0.5));
386 Matrix centered = sub(x,repmat(m, 1, npoints));
387 Vector y;
388 if (dim != 1)
389 y = exp(sum(times(centered,(mult(in,centered)))).scale(-0.5)).scale(ff);
390 else
391 y = exp(mult(in,pow(centered,2)).asVector().scale(-0.5) ).scale(ff);
392 return y;
393
394 }
395 public static Vector t_multinorm(Matrix x, Vector m, Matrix covar, double v ){
396
397
398
399
400 int dim = x.numRows(), npoints=x.numColumns();
401 covar = add(covar, eye(dim,Double.MIN_VALUE));
402 double dd = det(covar);
403 Matrix in = inv(covar);
404 double ff = gamma(0.5*(v+dim))/(pow(PI*v,0.5*dim) * gamma(v/2) * sqrt(dd));
405 Matrix centered = sub(x,repmat(m, 1, npoints));
406 Vector y = new DenseVector(npoints);
407 if (dim != 1){
408 for (int j = 0; j <npoints; j++){
409 double d = multAt(centered.sliceCol(j),mult(in,centered.sliceCol(j))).asScalar();
410 y.set(j, ff / ( pow(1 + d/v,0.5*(v + dim)) ));
411 }
412
413
414 }
415 else
416 y = pow(add(1.0 , pow(mult(in,centered).asVector(),2).scale(1.0/v)) , (-0.5*(v+dim)) ).scale(ff);
417
418 return y;
419 }
420
421
422 }
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439