xref: /aosp_15_r20/external/rappor/analysis/cpp/fast_em.cc (revision 2abb31345f6c95944768b5222a9a5ed3fc68cc00)
1 // Copyright 2015 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include <assert.h>
#include <errno.h>  // errno, ERANGE
#include <limits.h>  // INT_MIN, INT_MAX
#include <stdarg.h>  // va_list, etc.
#include <stdint.h>  // uint16_t
#include <stdio.h>  // fread()
#include <stdlib.h>  // exit()
#include <string.h>  // strcmp()
#include <cmath>  // std::abs operates on doubles
#include <cstdlib>  // strtol
#include <vector>
24 
25 using std::vector;
26 
// Prints a printf-style formatted message to stdout and terminates it with a
// newline.
void log(const char* fmt, ...) {
  va_list ap;
  va_start(ap, fmt);
  vprintf(fmt, ap);
  va_end(ap);
  fputs("\n", stdout);
}
35 
36 const int kTagLen = 4;  // 4 byte tags in the file format
37 
ExpectTag(FILE * f,const char * tag)38 bool ExpectTag(FILE* f, const char* tag) {
39   char buf[kTagLen];
40 
41   if (fread(buf, sizeof buf[0], kTagLen, f) != kTagLen) {
42     return false;
43   }
44   if (strcmp(buf, tag) != 0) {
45     log("Error: expected '%s'", tag);
46     return false;
47   }
48   return true;
49 }
50 
ReadListOfMatrices(FILE * f,uint32_t * num_entries_out,uint32_t * entry_size_out,vector<double> * v_out)51 static bool ReadListOfMatrices(
52     FILE* f, uint32_t* num_entries_out, uint32_t* entry_size_out,
53     vector<double>* v_out) {
54   if (!ExpectTag(f, "ne ")) {
55     return false;
56   }
57 
58   // R integers are serialized as uint32_t
59   uint32_t num_entries;
60   if (fread(&num_entries, sizeof num_entries, 1, f) != 1) {
61     return false;
62   }
63 
64   log("num entries: %d", num_entries);
65 
66   if (!ExpectTag(f, "es ")) {
67     return false;
68   }
69 
70   uint32_t entry_size;
71   if (fread(&entry_size, sizeof entry_size, 1, f) != 1) {
72     return false;
73   }
74   log("entry_size: %d", entry_size);
75 
76   if (!ExpectTag(f, "dat")) {
77     return false;
78   }
79 
80   // Now read dynamic data
81   size_t vec_length = num_entries * entry_size;
82 
83   vector<double>& v = *v_out;
84   v.resize(vec_length);
85 
86   if (fread(&v[0], sizeof v[0], vec_length, f) != vec_length) {
87     return false;
88   }
89 
90   // Print out head for sanity
91   size_t n = 20;
92   for (size_t i = 0; i < n && i < v.size(); ++i) {
93     log("%d: %f", i, v[i]);
94   }
95 
96   *num_entries_out = num_entries;
97   *entry_size_out = entry_size;
98 
99   return true;
100 }
101 
PrintEntryVector(const vector<double> & cond_prob,size_t m,size_t entry_size)102 void PrintEntryVector(const vector<double>& cond_prob, size_t m,
103                       size_t entry_size) {
104   size_t c_base = m * entry_size;
105   log("cond_prob[m = %d] = ", m);
106   for (size_t i = 0; i < entry_size; ++i) {
107     printf("%e ", cond_prob[c_base + i]);
108   }
109   printf("\n");
110 }
111 
// Debug helper: prints every element of pij on a single line, followed by the
// sum of all elements.
void PrintPij(const vector<double>& pij) {
  printf("PIJ:\n");

  double total = 0.0;
  for (vector<double>::const_iterator it = pij.begin(); it != pij.end();
       ++it) {
    printf("%f ", *it);
    total += *it;
  }
  printf("\n");

  // total is 1.0 after normalization
  printf("SUM: %f\n", total);
  printf("\n");
}
123 
124 // EM algorithm to iteratively estimate parameters.
125 
ExpectationMaximization(uint32_t num_entries,uint32_t entry_size,const vector<double> & cond_prob,int max_em_iters,double epsilon,vector<double> * pij_out)126 static int ExpectationMaximization(
127     uint32_t num_entries, uint32_t entry_size, const vector<double>& cond_prob,
128     int max_em_iters, double epsilon, vector<double>* pij_out) {
129   // Start out with uniform distribution.
130   vector<double> pij(entry_size, 0.0);
131   double init = 1.0 / entry_size;
132   for (size_t i = 0; i < pij.size(); ++i) {
133     pij[i] = init;
134   }
135   log("Initialized %d entries with %f", pij.size(), init);
136 
137   vector<double> prev_pij(entry_size, 0.0);  // pij on previous iteration
138 
139   log("Starting up to %d EM iterations", max_em_iters);
140 
141   int em_iter = 0;  // visible after loop
142   for (; em_iter < max_em_iters; ++em_iter) {
143     //
144     // lapply() step.
145     //
146 
147     // Computed below as a function of old Pij and conditional probability for
148     // each report.
149     vector<double> new_pij(entry_size, 0.0);
150 
151     // m is the matrix index, giving the conditional probability matrix for a
152     // single report.
153     for (size_t m = 0; m < num_entries; ++m) {
154       vector<double> z(entry_size, 0.0);
155 
156       double sum_z = 0.0;
157 
158       // base index for the matrix corresponding to a report.
159       size_t c_base = m * entry_size;
160 
161       for (size_t i = 0; i < entry_size; ++i) {  // multiply and running sum
162         size_t c_index = c_base + i;
163         z[i] = cond_prob[c_index] * pij[i];
164         sum_z += z[i];
165       }
166 
167       // Normalize and Reduce("+", wcp) step.  These two steps are combined for
168       // memory locality.
169       for (size_t i = 0; i < entry_size; ++i) {
170         new_pij[i] += z[i] / sum_z;
171       }
172     }
173 
174     // Divide outside the loop
175     for (size_t i = 0; i < entry_size; ++i) {
176       new_pij[i] /= num_entries;
177     }
178 
179     //PrintPij(new_pij);
180 
181     //
182     // Check for termination
183     //
184     double max_dif = 0.0;
185     for (size_t i = 0; i < entry_size; ++i) {
186       double dif = std::abs(new_pij[i] - pij[i]);
187       if (dif > max_dif) {
188         max_dif = dif;
189       }
190     }
191 
192     pij = new_pij;  // copy
193 
194     log("fast EM iteration %d, dif = %e", em_iter, max_dif);
195 
196     if (max_dif < epsilon) {
197       log("Early EM termination: %e < %e", max_dif, epsilon);
198       break;
199     }
200   }
201 
202   *pij_out = pij;
203   // If we reached iteration index 10, then there were 10 iterations: the last
204   // one terminated the loop.
205   return em_iter;
206 }
207 
// Writes a 3-character tag followed by its terminating NUL byte (4 bytes in
// total) to f_out.  Returns true if all 4 bytes were written, false on a short
// write or if `tag` is not exactly 3 characters long.
bool WriteTag(const char* tag, FILE* f_out) {
  const size_t kTagChars = 3;
  const size_t tag_chars = strlen(tag);
  assert(tag_chars == kTagChars);  // write 3 byte tags with NUL byte
  if (tag_chars != kTagChars) {
    // assert() is compiled out under NDEBUG.  Refuse the tag explicitly so a
    // shorter string is never read past its end by fwrite() below.
    return false;
  }
  const size_t num_bytes = kTagChars + 1;  // tag characters plus the NUL
  return fwrite(tag, 1, num_bytes, f_out) == num_bytes;
}
212 
213 // Write the probabilities as a flat list of doubles.  The caller knows what
214 // the dimensions are.
WriteResult(const vector<double> & pij,uint32_t num_em_iters,FILE * f_out)215 bool WriteResult(const vector<double>& pij, uint32_t num_em_iters,
216                  FILE* f_out) {
217   if (!WriteTag("emi", f_out)) {
218     return false;
219   }
220   if (fwrite(&num_em_iters, sizeof num_em_iters, 1, f_out) != 1) {
221     return false;
222   }
223 
224   if (!WriteTag("pij", f_out)) {
225     return false;
226   }
227   size_t n = pij.size();
228   if (fwrite(&pij[0], sizeof pij[0], n, f_out) != n) {
229     return false;
230   }
231   return true;
232 }
233 
// Like atoi, but with basic (not exhaustive) error checking.
//
// Parses a base-10 integer from the start of s.  On success stores the value
// in *result and returns true.  Returns false, with *result set to 0, if no
// digits could be parsed or if the value does not fit in an int.  Characters
// following the number are not checked.
bool StringToInt(const char* s, int* result) {
  char* end;  // mutated by strtol

  errno = 0;  // strtol only sets errno on failure, so clear it first
  const long value = strtol(s, &end, 10);  // base 10

  // If strtol didn't consume any characters, it failed.
  if (end == s) {
    *result = 0;
    return false;
  }
  // strtol reports values outside the range of long through ERANGE; values
  // that fit in a long but not in an int must be rejected too, otherwise the
  // assignment below would silently truncate them.
  if (errno == ERANGE || value < INT_MIN || value > INT_MAX) {
    *result = 0;
    return false;
  }

  *result = static_cast<int>(value);
  return true;
}
246 
main(int argc,char ** argv)247 int main(int argc, char **argv) {
248   if (argc < 4) {
249     log("Usage: read_numeric INPUT OUTPUT max_em_iters");
250     return 1;
251   }
252 
253   char* in_filename = argv[1];
254   char* out_filename = argv[2];
255 
256   int max_em_iters;
257   if (!StringToInt(argv[3], &max_em_iters)) {
258     log("Error parsing max_em_iters");
259     return 1;
260   }
261 
262   FILE* f = fopen(in_filename, "rb");
263   if (f == NULL) {
264     return 1;
265   }
266 
267   // Try opening first so we don't do a long computation and then fail.
268   FILE* f_out = fopen(out_filename, "wb");
269   if (f_out == NULL) {
270     return 1;
271   }
272 
273   uint32_t num_entries;
274   uint32_t entry_size;
275   vector<double> cond_prob;
276   if (!ReadListOfMatrices(f, &num_entries, &entry_size, &cond_prob)) {
277     log("Error reading list of matrices");
278     return 1;
279   }
280 
281   fclose(f);
282 
283   // Sanity check
284   double debug_sum = 0.0;
285   for (size_t m = 0; m < num_entries; ++m) {
286     // base index for the matrix corresponding to a report.
287     size_t c_base = m * entry_size;
288     for (size_t i = 0; i < entry_size; ++i) {  // multiply and running sum
289       debug_sum += cond_prob[c_base + i];
290     }
291   }
292   log("Debug sum: %f", debug_sum);
293 
294   double epsilon = 1e-6;
295   log("epsilon: %f", epsilon);
296 
297   vector<double> pij(entry_size);
298   int num_em_iters = ExpectationMaximization(
299       num_entries, entry_size, cond_prob, max_em_iters, epsilon, &pij);
300 
301   if (!WriteResult(pij, num_em_iters, f_out)) {
302     log("Error writing result matrix");
303     return 1;
304   }
305   fclose(f_out);
306 
307   log("fast EM done");
308   return 0;
309 }
310