jormungand.net » projects » misc
graphics · em · random · mtgarden

EM Clustering Algorithm

A word of caution

This web page shows up in search results for "em clustering" at a rank far better than my expertise in the matter justifies; I only wrote this for fun and to help understand it myself. The Wikipedia description of EM is far more authoritative.

Overview

Suppose you have a number of samples drawn from a distribution which can be approximated by (or in this case, is, since the data is made up) a mixture of gaussian distributions and you wish to estimate the parameters of each gaussian and assign each datum to a particular one. The EM, or Expectation Maximization algorithm provides a framework for doing so.


ex8r.png

8 bivariate gaussian distributions with random parameters have been sampled 1000 times. The EM algorithm is used to classify each point into the most likely gaussian (shown as point color) and estimate the parameters of each distribution (shown as ellipses). This is generally more interesting in higher dimensions.


ex8.png

Expectation-maximization, as expected, works in two alternating steps. Expectation refers to computing the probability that each datum is a member of each class; maximization refers to altering the parameters of each class to maximize those probabilities. Eventually it converges, though not necessarily correctly.

New: Javascript Version

cluster demo







This is an interactive demo of the 2d k-means and EM (2d gaussian mixture) clustering algorithms.



C++ Code

See em.tar.gz for the complete code and data files. The interesting stuff is printed below.

Disclaimer: this is messy, buggy, and incomplete, quite possibly downright wrong, and has all sorts of compile warnings. If you are passing this off as your own CS homework, you ought to fix at least the warnings...


  1 #include "M.H"
  2 #include <stdio.h>
  3 #include <fcntl.h>
  4 #include <vector>
  5 
  6 int main(int argc, char **argv){
  7 
  8   srand(time(0));
  9   
 10   vector<V2> data;
 11 
 12   if(argc < 4){
 13     fprintf(stderr,"usage: %s <#classes> <#iterations> <datafile>\n", argv[0]);
 14     return -1;
 15   }
 16   
 17   int nclasses = atoi(argv[1]);
 18   int niteration = atoi(argv[2]);
 19 
 20   if(!nclasses){
 21     fprintf(stderr, "classes should be nonzero\n");
 22     return -1;
 23   }
 24   
 25   char * fname = argv[3];
 26   FILE * file = fopen(fname, "r");
 27   if(!file){
 28     fprintf(stderr, "could not open '%s'\n", fname);
 29     return -1;
 30   }
 31 
 32   char buffer[80];
 33   while(fgets(buffer, sizeof(buffer), file)){
 34     V2 p;
 35     int r = sscanf(buffer, "%lf  %lf", &p.x, &p.y);
 36     if(r == 2){
 37       data.push_back(p);
 38     }else{
 39       fprintf(stderr, "sscanf returned %d on '%s'\n", buffer);
 40     }
 41   }
 42   fprintf(stderr, "loaded %d datums\n", data.size());
 43 
 44   fclose(file);
 45 
 46   vector<M23> classes(nclasses);
 47 
 48   for(unsigned int cls=0; cls<classes.size(); ++cls){
 49     classes[cls].a = 5*(rand()/(double)RAND_MAX);
 50     classes[cls].b = 0;
 51     classes[cls].c = 30*(rand()/(double)RAND_MAX);
 52     classes[cls].d = 0;
 53     classes[cls].e = 5*(rand()/(double)RAND_MAX);
 54     classes[cls].f = 30*(rand()/(double)RAND_MAX);
 55   }
 56 
 57   vector<vector<double> > prob_cls;
 58   for(unsigned int i=0; i<data.size(); ++i){
 59     prob_cls.push_back(vector<double>(classes.size()));
 60   }
 61 
 62   int iteration = 0;
 63   while(1){
 64     fprintf(stderr, "iteration %d ...\n", iteration++);
 65 
 66     // compute probability of each datum being in each class
 67     for(unsigned int cls = 0; cls < classes.size(); ++cls){
 68       V2 mean = classes[cls].zv();
 69       M cov(2,2);
 70       cov(0,0) = classes[cls].a;
 71       cov(0,1) = classes[cls].b;
 72       cov(1,0) = classes[cls].d;
 73       cov(1,1) = classes[cls].e;
 74 
 75       M icov = inv(cov);
 76       double det = classes[cls].a * classes[cls].e - classes[cls].b * classes[cls].d;
 77       double p2 = 1.0 / (2*M_PI*sqrt(det));
 78       
 79       for(unsigned int inst = 0; inst < data.size(); ++inst){
 80         V2 x_minus_u = data[inst] - mean;
 81         double p1 = exp( -0.5 * ( tr((M)(x_minus_u)) * icov * ((M)x_minus_u)) );
 82         double p = p1*p2;
 83         prob_cls[inst][cls] = p;
 84       }
 85     }
 86 
 87     for(unsigned int inst = 0; inst < data.size(); ++inst){
 88       double s = 0;
 89       int b;
 90       double bv = 0;
 91       for(unsigned int cls = 0; cls < classes.size(); ++cls){
 92         s += prob_cls[inst][cls];
 93         if(prob_cls[inst][cls] > bv){
 94           bv = prob_cls[inst][cls];
 95           b = cls;
 96         }
 97       }
 98       for(unsigned int cls = 0; cls < classes.size(); ++cls){
 99         prob_cls[inst][cls] /= s;
100       }
101     }
102 
103     // compute mean, covariance statistics for each class
104 
105     for(unsigned int cls = 0; cls < classes.size(); ++cls){
106       V2 mean;
107       double q = 0;
108       for(unsigned int inst = 0; inst < data.size(); ++inst){
109         mean = mean + prob_cls[inst][cls] * data[inst];
110         q += prob_cls[inst][cls];
111       }
112       mean = mean * (1/q);
113 
114       double xx=0, yy=0, xy=0;;
115       
116       for(unsigned int inst = 0; inst < data.size(); ++inst){
117         double dx = data[inst].x-mean.x;
118         double dy = data[inst].y-mean.y;
119         xx += dx*dx * prob_cls[inst][cls];
120         yy += dy*dy * prob_cls[inst][cls];
121         xy += dx*dy * prob_cls[inst][cls];
122       }
123       xx /= q;
124       yy /= q;
125       xy /= q;
126 
127       printf("class %d new parameters: mean=(%.3f,%.3f) xx=%.3f yy=%.3f xy=%.3f\n",
128           cls, mean.x, mean.y, xx, yy, xy);
129 
130       classes[cls].a = xx;
131       classes[cls].b = xy;
132       classes[cls].d = xy;
133       classes[cls].e = yy;
134       classes[cls].x = mean.x;
135       classes[cls].y = mean.y;
136     }
137         
138 
139     if(iteration == niteration) break;
140     
141   }
142 
143   printf("\nplot");
144 
145   vector<FILE*> outputs(classes.size());
146   for(unsigned int cls=0; cls<outputs.size(); ++cls){
147     char buffer[50];
148     sprintf(buffer, "cls%02d.txt", cls);
149     outputs[cls] = fopen(buffer, "w");
150     printf(" '%s',", buffer);
151   }
152   
153 
154   for(unsigned int inst = 0; inst < data.size(); ++inst){
155     double s = 0;
156     int b;
157     double bv = 0;
158     double r = rand()/(double)RAND_MAX;
159     for(unsigned int cls = 0; cls < classes.size(); ++cls){
160       if(r < prob_cls[inst][cls]){
161         b = cls;
162         break;
163       }else{
164         r += prob_cls[inst][cls];
165       }
166     }
167     fprintf(outputs[b], "%f\t%f\n", data[inst].x, data[inst].y);
168   }
169 
170   for(unsigned int cls=0; cls<outputs.size(); ++cls){
171     fclose(outputs[cls]);
172   }
173 
174   for(unsigned int cls=0; cls<outputs.size(); ++cls){
175     char buffer[50];
176     sprintf(buffer, "cls%02del.txt", cls);
177     FILE * out = fopen(buffer, "w");
178 
179     V2 mean = classes[cls].zv();
180     M cov(2,2);
181     cov(0,0) = classes[cls].a;
182     cov(0,1) = classes[cls].b;
183     cov(1,0) = classes[cls].d;
184     cov(1,1) = classes[cls].e;
185     M icov = inv(cov);
186     
187     double det = classes[cls].a * classes[cls].e - classes[cls].b * classes[cls].d;
188 
189     for(double th=0; th<2*M_PI; th+=M_PI/128){
190         V2 v(cos(th),sin(th));
191 
192         double p = tr((M)(v)) * icov * ((M)v);
193         v = v * (3/sqrt(p));
194 
195         fprintf(out, "%f\t%f\n", v.x+mean.x, v.y+mean.y);
196     }
197     fclose(out);
198     printf(" '%s' with lines,", buffer);
199   }
200   printf("\n\n");
201 
202   return 0;
203 }



© 2000-now
chris@jormungand.net