xref: /aosp_15_r20/external/rappor/analysis/cpp/fast_em.cc (revision 2abb31345f6c95944768b5222a9a5ed3fc68cc00)
// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14*2abb3134SXin Li 
#include <assert.h>
#include <errno.h>  // errno, ERANGE
#include <limits.h>  // INT_MIN, INT_MAX
#include <stdarg.h>  // va_list, etc.
#include <stdint.h>  // uint16_t
#include <stdio.h>  // fread()
#include <stdlib.h>  // exit()
#include <string.h>  // strcmp()

#include <cmath>  // std::abs operates on doubles
#include <cstdlib>  // strtol
#include <vector>

using std::vector;
26*2abb3134SXin Li 
// Prints a printf-style formatted message to stdout and terminates it with a
// newline.
void log(const char* fmt, ...) {
  va_list ap;
  va_start(ap, fmt);
  vfprintf(stdout, fmt, ap);
  va_end(ap);
  fputs("\n", stdout);
}
35*2abb3134SXin Li 
36*2abb3134SXin Li const int kTagLen = 4;  // 4 byte tags in the file format
37*2abb3134SXin Li 
ExpectTag(FILE * f,const char * tag)38*2abb3134SXin Li bool ExpectTag(FILE* f, const char* tag) {
39*2abb3134SXin Li   char buf[kTagLen];
40*2abb3134SXin Li 
41*2abb3134SXin Li   if (fread(buf, sizeof buf[0], kTagLen, f) != kTagLen) {
42*2abb3134SXin Li     return false;
43*2abb3134SXin Li   }
44*2abb3134SXin Li   if (strcmp(buf, tag) != 0) {
45*2abb3134SXin Li     log("Error: expected '%s'", tag);
46*2abb3134SXin Li     return false;
47*2abb3134SXin Li   }
48*2abb3134SXin Li   return true;
49*2abb3134SXin Li }
50*2abb3134SXin Li 
ReadListOfMatrices(FILE * f,uint32_t * num_entries_out,uint32_t * entry_size_out,vector<double> * v_out)51*2abb3134SXin Li static bool ReadListOfMatrices(
52*2abb3134SXin Li     FILE* f, uint32_t* num_entries_out, uint32_t* entry_size_out,
53*2abb3134SXin Li     vector<double>* v_out) {
54*2abb3134SXin Li   if (!ExpectTag(f, "ne ")) {
55*2abb3134SXin Li     return false;
56*2abb3134SXin Li   }
57*2abb3134SXin Li 
58*2abb3134SXin Li   // R integers are serialized as uint32_t
59*2abb3134SXin Li   uint32_t num_entries;
60*2abb3134SXin Li   if (fread(&num_entries, sizeof num_entries, 1, f) != 1) {
61*2abb3134SXin Li     return false;
62*2abb3134SXin Li   }
63*2abb3134SXin Li 
64*2abb3134SXin Li   log("num entries: %d", num_entries);
65*2abb3134SXin Li 
66*2abb3134SXin Li   if (!ExpectTag(f, "es ")) {
67*2abb3134SXin Li     return false;
68*2abb3134SXin Li   }
69*2abb3134SXin Li 
70*2abb3134SXin Li   uint32_t entry_size;
71*2abb3134SXin Li   if (fread(&entry_size, sizeof entry_size, 1, f) != 1) {
72*2abb3134SXin Li     return false;
73*2abb3134SXin Li   }
74*2abb3134SXin Li   log("entry_size: %d", entry_size);
75*2abb3134SXin Li 
76*2abb3134SXin Li   if (!ExpectTag(f, "dat")) {
77*2abb3134SXin Li     return false;
78*2abb3134SXin Li   }
79*2abb3134SXin Li 
80*2abb3134SXin Li   // Now read dynamic data
81*2abb3134SXin Li   size_t vec_length = num_entries * entry_size;
82*2abb3134SXin Li 
83*2abb3134SXin Li   vector<double>& v = *v_out;
84*2abb3134SXin Li   v.resize(vec_length);
85*2abb3134SXin Li 
86*2abb3134SXin Li   if (fread(&v[0], sizeof v[0], vec_length, f) != vec_length) {
87*2abb3134SXin Li     return false;
88*2abb3134SXin Li   }
89*2abb3134SXin Li 
90*2abb3134SXin Li   // Print out head for sanity
91*2abb3134SXin Li   size_t n = 20;
92*2abb3134SXin Li   for (size_t i = 0; i < n && i < v.size(); ++i) {
93*2abb3134SXin Li     log("%d: %f", i, v[i]);
94*2abb3134SXin Li   }
95*2abb3134SXin Li 
96*2abb3134SXin Li   *num_entries_out = num_entries;
97*2abb3134SXin Li   *entry_size_out = entry_size;
98*2abb3134SXin Li 
99*2abb3134SXin Li   return true;
100*2abb3134SXin Li }
101*2abb3134SXin Li 
PrintEntryVector(const vector<double> & cond_prob,size_t m,size_t entry_size)102*2abb3134SXin Li void PrintEntryVector(const vector<double>& cond_prob, size_t m,
103*2abb3134SXin Li                       size_t entry_size) {
104*2abb3134SXin Li   size_t c_base = m * entry_size;
105*2abb3134SXin Li   log("cond_prob[m = %d] = ", m);
106*2abb3134SXin Li   for (size_t i = 0; i < entry_size; ++i) {
107*2abb3134SXin Li     printf("%e ", cond_prob[c_base + i]);
108*2abb3134SXin Li   }
109*2abb3134SXin Li   printf("\n");
110*2abb3134SXin Li }
111*2abb3134SXin Li 
// Debug helper: prints every element of pij on one line, followed by the sum
// of all elements.
void PrintPij(const vector<double>& pij) {
  printf("PIJ:\n");

  double total = 0.0;
  for (vector<double>::const_iterator it = pij.begin(); it != pij.end();
       ++it) {
    printf("%f ", *it);
    total += *it;
  }
  printf("\n");

  // total is 1.0 after normalization
  printf("SUM: %f\n", total);
  printf("\n");
}
123*2abb3134SXin Li 
124*2abb3134SXin Li // EM algorithm to iteratively estimate parameters.
125*2abb3134SXin Li 
ExpectationMaximization(uint32_t num_entries,uint32_t entry_size,const vector<double> & cond_prob,int max_em_iters,double epsilon,vector<double> * pij_out)126*2abb3134SXin Li static int ExpectationMaximization(
127*2abb3134SXin Li     uint32_t num_entries, uint32_t entry_size, const vector<double>& cond_prob,
128*2abb3134SXin Li     int max_em_iters, double epsilon, vector<double>* pij_out) {
129*2abb3134SXin Li   // Start out with uniform distribution.
130*2abb3134SXin Li   vector<double> pij(entry_size, 0.0);
131*2abb3134SXin Li   double init = 1.0 / entry_size;
132*2abb3134SXin Li   for (size_t i = 0; i < pij.size(); ++i) {
133*2abb3134SXin Li     pij[i] = init;
134*2abb3134SXin Li   }
135*2abb3134SXin Li   log("Initialized %d entries with %f", pij.size(), init);
136*2abb3134SXin Li 
137*2abb3134SXin Li   vector<double> prev_pij(entry_size, 0.0);  // pij on previous iteration
138*2abb3134SXin Li 
139*2abb3134SXin Li   log("Starting up to %d EM iterations", max_em_iters);
140*2abb3134SXin Li 
141*2abb3134SXin Li   int em_iter = 0;  // visible after loop
142*2abb3134SXin Li   for (; em_iter < max_em_iters; ++em_iter) {
143*2abb3134SXin Li     //
144*2abb3134SXin Li     // lapply() step.
145*2abb3134SXin Li     //
146*2abb3134SXin Li 
147*2abb3134SXin Li     // Computed below as a function of old Pij and conditional probability for
148*2abb3134SXin Li     // each report.
149*2abb3134SXin Li     vector<double> new_pij(entry_size, 0.0);
150*2abb3134SXin Li 
151*2abb3134SXin Li     // m is the matrix index, giving the conditional probability matrix for a
152*2abb3134SXin Li     // single report.
153*2abb3134SXin Li     for (size_t m = 0; m < num_entries; ++m) {
154*2abb3134SXin Li       vector<double> z(entry_size, 0.0);
155*2abb3134SXin Li 
156*2abb3134SXin Li       double sum_z = 0.0;
157*2abb3134SXin Li 
158*2abb3134SXin Li       // base index for the matrix corresponding to a report.
159*2abb3134SXin Li       size_t c_base = m * entry_size;
160*2abb3134SXin Li 
161*2abb3134SXin Li       for (size_t i = 0; i < entry_size; ++i) {  // multiply and running sum
162*2abb3134SXin Li         size_t c_index = c_base + i;
163*2abb3134SXin Li         z[i] = cond_prob[c_index] * pij[i];
164*2abb3134SXin Li         sum_z += z[i];
165*2abb3134SXin Li       }
166*2abb3134SXin Li 
167*2abb3134SXin Li       // Normalize and Reduce("+", wcp) step.  These two steps are combined for
168*2abb3134SXin Li       // memory locality.
169*2abb3134SXin Li       for (size_t i = 0; i < entry_size; ++i) {
170*2abb3134SXin Li         new_pij[i] += z[i] / sum_z;
171*2abb3134SXin Li       }
172*2abb3134SXin Li     }
173*2abb3134SXin Li 
174*2abb3134SXin Li     // Divide outside the loop
175*2abb3134SXin Li     for (size_t i = 0; i < entry_size; ++i) {
176*2abb3134SXin Li       new_pij[i] /= num_entries;
177*2abb3134SXin Li     }
178*2abb3134SXin Li 
179*2abb3134SXin Li     //PrintPij(new_pij);
180*2abb3134SXin Li 
181*2abb3134SXin Li     //
182*2abb3134SXin Li     // Check for termination
183*2abb3134SXin Li     //
184*2abb3134SXin Li     double max_dif = 0.0;
185*2abb3134SXin Li     for (size_t i = 0; i < entry_size; ++i) {
186*2abb3134SXin Li       double dif = std::abs(new_pij[i] - pij[i]);
187*2abb3134SXin Li       if (dif > max_dif) {
188*2abb3134SXin Li         max_dif = dif;
189*2abb3134SXin Li       }
190*2abb3134SXin Li     }
191*2abb3134SXin Li 
192*2abb3134SXin Li     pij = new_pij;  // copy
193*2abb3134SXin Li 
194*2abb3134SXin Li     log("fast EM iteration %d, dif = %e", em_iter, max_dif);
195*2abb3134SXin Li 
196*2abb3134SXin Li     if (max_dif < epsilon) {
197*2abb3134SXin Li       log("Early EM termination: %e < %e", max_dif, epsilon);
198*2abb3134SXin Li       break;
199*2abb3134SXin Li     }
200*2abb3134SXin Li   }
201*2abb3134SXin Li 
202*2abb3134SXin Li   *pij_out = pij;
203*2abb3134SXin Li   // If we reached iteration index 10, then there were 10 iterations: the last
204*2abb3134SXin Li   // one terminated the loop.
205*2abb3134SXin Li   return em_iter;
206*2abb3134SXin Li }
207*2abb3134SXin Li 
// Writes a 3-character tag plus its terminating NUL (4 bytes in total) to
// f_out.
//
// Returns true if all 4 bytes were written.  Returns false, writing nothing,
// if `tag` is not exactly 3 characters long: fwrite() below always reads 4
// bytes, so a shorter tag would be read past its end.  This is checked at run
// time rather than with assert() so that it still holds when NDEBUG is
// defined.
bool WriteTag(const char* tag, FILE* f_out) {
  const size_t kTagChars = 3;              // visible characters in a tag
  const size_t kTagBytes = kTagChars + 1;  // plus the NUL byte

  if (strlen(tag) != kTagChars) {
    return false;
  }
  return fwrite(tag, 1, kTagBytes, f_out) == kTagBytes;
}
212*2abb3134SXin Li 
213*2abb3134SXin Li // Write the probabilities as a flat list of doubles.  The caller knows what
214*2abb3134SXin Li // the dimensions are.
WriteResult(const vector<double> & pij,uint32_t num_em_iters,FILE * f_out)215*2abb3134SXin Li bool WriteResult(const vector<double>& pij, uint32_t num_em_iters,
216*2abb3134SXin Li                  FILE* f_out) {
217*2abb3134SXin Li   if (!WriteTag("emi", f_out)) {
218*2abb3134SXin Li     return false;
219*2abb3134SXin Li   }
220*2abb3134SXin Li   if (fwrite(&num_em_iters, sizeof num_em_iters, 1, f_out) != 1) {
221*2abb3134SXin Li     return false;
222*2abb3134SXin Li   }
223*2abb3134SXin Li 
224*2abb3134SXin Li   if (!WriteTag("pij", f_out)) {
225*2abb3134SXin Li     return false;
226*2abb3134SXin Li   }
227*2abb3134SXin Li   size_t n = pij.size();
228*2abb3134SXin Li   if (fwrite(&pij[0], sizeof pij[0], n, f_out) != n) {
229*2abb3134SXin Li     return false;
230*2abb3134SXin Li   }
231*2abb3134SXin Li   return true;
232*2abb3134SXin Li }
233*2abb3134SXin Li 
// Like atoi, but with basic (not exhaustive) error checking.
//
// Parses a base-10 integer from the start of s.  Returns true and stores the
// value in *result on success.  Returns false and stores 0 in *result if
// strtol consumed no characters or the value does not fit in an int.
// Characters following the number are not treated as an error.
// Resets and may set errno.
bool StringToInt(const char* s, int* result) {
  char* end = NULL;  // set by strtol to the first unparsed character

  errno = 0;  // strtol reports overflow of long only through errno
  const long parsed = strtol(s, &end, 10);  // base 10

  // If strtol didn't consume any characters, it failed.
  const bool no_digits = (end == s);
  // long may be wider than int, so also range-check before narrowing.
  const bool out_of_range =
      (errno == ERANGE) || parsed < INT_MIN || parsed > INT_MAX;

  if (no_digits || out_of_range) {
    *result = 0;
    return false;
  }
  *result = static_cast<int>(parsed);
  return true;
}
246*2abb3134SXin Li 
main(int argc,char ** argv)247*2abb3134SXin Li int main(int argc, char **argv) {
248*2abb3134SXin Li   if (argc < 4) {
249*2abb3134SXin Li     log("Usage: read_numeric INPUT OUTPUT max_em_iters");
250*2abb3134SXin Li     return 1;
251*2abb3134SXin Li   }
252*2abb3134SXin Li 
253*2abb3134SXin Li   char* in_filename = argv[1];
254*2abb3134SXin Li   char* out_filename = argv[2];
255*2abb3134SXin Li 
256*2abb3134SXin Li   int max_em_iters;
257*2abb3134SXin Li   if (!StringToInt(argv[3], &max_em_iters)) {
258*2abb3134SXin Li     log("Error parsing max_em_iters");
259*2abb3134SXin Li     return 1;
260*2abb3134SXin Li   }
261*2abb3134SXin Li 
262*2abb3134SXin Li   FILE* f = fopen(in_filename, "rb");
263*2abb3134SXin Li   if (f == NULL) {
264*2abb3134SXin Li     return 1;
265*2abb3134SXin Li   }
266*2abb3134SXin Li 
267*2abb3134SXin Li   // Try opening first so we don't do a long computation and then fail.
268*2abb3134SXin Li   FILE* f_out = fopen(out_filename, "wb");
269*2abb3134SXin Li   if (f_out == NULL) {
270*2abb3134SXin Li     return 1;
271*2abb3134SXin Li   }
272*2abb3134SXin Li 
273*2abb3134SXin Li   uint32_t num_entries;
274*2abb3134SXin Li   uint32_t entry_size;
275*2abb3134SXin Li   vector<double> cond_prob;
276*2abb3134SXin Li   if (!ReadListOfMatrices(f, &num_entries, &entry_size, &cond_prob)) {
277*2abb3134SXin Li     log("Error reading list of matrices");
278*2abb3134SXin Li     return 1;
279*2abb3134SXin Li   }
280*2abb3134SXin Li 
281*2abb3134SXin Li   fclose(f);
282*2abb3134SXin Li 
283*2abb3134SXin Li   // Sanity check
284*2abb3134SXin Li   double debug_sum = 0.0;
285*2abb3134SXin Li   for (size_t m = 0; m < num_entries; ++m) {
286*2abb3134SXin Li     // base index for the matrix corresponding to a report.
287*2abb3134SXin Li     size_t c_base = m * entry_size;
288*2abb3134SXin Li     for (size_t i = 0; i < entry_size; ++i) {  // multiply and running sum
289*2abb3134SXin Li       debug_sum += cond_prob[c_base + i];
290*2abb3134SXin Li     }
291*2abb3134SXin Li   }
292*2abb3134SXin Li   log("Debug sum: %f", debug_sum);
293*2abb3134SXin Li 
294*2abb3134SXin Li   double epsilon = 1e-6;
295*2abb3134SXin Li   log("epsilon: %f", epsilon);
296*2abb3134SXin Li 
297*2abb3134SXin Li   vector<double> pij(entry_size);
298*2abb3134SXin Li   int num_em_iters = ExpectationMaximization(
299*2abb3134SXin Li       num_entries, entry_size, cond_prob, max_em_iters, epsilon, &pij);
300*2abb3134SXin Li 
301*2abb3134SXin Li   if (!WriteResult(pij, num_em_iters, f_out)) {
302*2abb3134SXin Li     log("Error writing result matrix");
303*2abb3134SXin Li     return 1;
304*2abb3134SXin Li   }
305*2abb3134SXin Li   fclose(f_out);
306*2abb3134SXin Li 
307*2abb3134SXin Li   log("fast EM done");
308*2abb3134SXin Li   return 0;
309*2abb3134SXin Li }
310