1*2abb3134SXin Li // Copyright 2015 Google Inc. All rights reserved.
2*2abb3134SXin Li //
3*2abb3134SXin Li // Licensed under the Apache License, Version 2.0 (the "License");
4*2abb3134SXin Li // you may not use this file except in compliance with the License.
5*2abb3134SXin Li // You may obtain a copy of the License at
6*2abb3134SXin Li //
7*2abb3134SXin Li // http://www.apache.org/licenses/LICENSE-2.0
8*2abb3134SXin Li //
9*2abb3134SXin Li // Unless required by applicable law or agreed to in writing, software
10*2abb3134SXin Li // distributed under the License is distributed on an "AS IS" BASIS,
11*2abb3134SXin Li // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*2abb3134SXin Li // See the License for the specific language governing permissions and
13*2abb3134SXin Li // limitations under the License.
14*2abb3134SXin Li
#include <assert.h>
#include <errno.h>   // errno, ERANGE
#include <limits.h>  // INT_MIN, INT_MAX
#include <stdarg.h>  // va_list, etc.
#include <stdint.h>  // uint16_t
#include <stdio.h>   // fread()
#include <stdlib.h>  // exit()
#include <string.h>  // strcmp()

#include <cmath>    // std::abs operates on doubles
#include <cstdlib>  // strtol
#include <vector>
24*2abb3134SXin Li
25*2abb3134SXin Li using std::vector;
26*2abb3134SXin Li
// Emits a printf-style formatted message on stdout and terminates it with a
// newline.
void log(const char* fmt, ...) {
  va_list ap;
  va_start(ap, fmt);
  vfprintf(stdout, fmt, ap);
  va_end(ap);
  fputc('\n', stdout);
}
35*2abb3134SXin Li
36*2abb3134SXin Li const int kTagLen = 4; // 4 byte tags in the file format
37*2abb3134SXin Li
ExpectTag(FILE * f,const char * tag)38*2abb3134SXin Li bool ExpectTag(FILE* f, const char* tag) {
39*2abb3134SXin Li char buf[kTagLen];
40*2abb3134SXin Li
41*2abb3134SXin Li if (fread(buf, sizeof buf[0], kTagLen, f) != kTagLen) {
42*2abb3134SXin Li return false;
43*2abb3134SXin Li }
44*2abb3134SXin Li if (strcmp(buf, tag) != 0) {
45*2abb3134SXin Li log("Error: expected '%s'", tag);
46*2abb3134SXin Li return false;
47*2abb3134SXin Li }
48*2abb3134SXin Li return true;
49*2abb3134SXin Li }
50*2abb3134SXin Li
ReadListOfMatrices(FILE * f,uint32_t * num_entries_out,uint32_t * entry_size_out,vector<double> * v_out)51*2abb3134SXin Li static bool ReadListOfMatrices(
52*2abb3134SXin Li FILE* f, uint32_t* num_entries_out, uint32_t* entry_size_out,
53*2abb3134SXin Li vector<double>* v_out) {
54*2abb3134SXin Li if (!ExpectTag(f, "ne ")) {
55*2abb3134SXin Li return false;
56*2abb3134SXin Li }
57*2abb3134SXin Li
58*2abb3134SXin Li // R integers are serialized as uint32_t
59*2abb3134SXin Li uint32_t num_entries;
60*2abb3134SXin Li if (fread(&num_entries, sizeof num_entries, 1, f) != 1) {
61*2abb3134SXin Li return false;
62*2abb3134SXin Li }
63*2abb3134SXin Li
64*2abb3134SXin Li log("num entries: %d", num_entries);
65*2abb3134SXin Li
66*2abb3134SXin Li if (!ExpectTag(f, "es ")) {
67*2abb3134SXin Li return false;
68*2abb3134SXin Li }
69*2abb3134SXin Li
70*2abb3134SXin Li uint32_t entry_size;
71*2abb3134SXin Li if (fread(&entry_size, sizeof entry_size, 1, f) != 1) {
72*2abb3134SXin Li return false;
73*2abb3134SXin Li }
74*2abb3134SXin Li log("entry_size: %d", entry_size);
75*2abb3134SXin Li
76*2abb3134SXin Li if (!ExpectTag(f, "dat")) {
77*2abb3134SXin Li return false;
78*2abb3134SXin Li }
79*2abb3134SXin Li
80*2abb3134SXin Li // Now read dynamic data
81*2abb3134SXin Li size_t vec_length = num_entries * entry_size;
82*2abb3134SXin Li
83*2abb3134SXin Li vector<double>& v = *v_out;
84*2abb3134SXin Li v.resize(vec_length);
85*2abb3134SXin Li
86*2abb3134SXin Li if (fread(&v[0], sizeof v[0], vec_length, f) != vec_length) {
87*2abb3134SXin Li return false;
88*2abb3134SXin Li }
89*2abb3134SXin Li
90*2abb3134SXin Li // Print out head for sanity
91*2abb3134SXin Li size_t n = 20;
92*2abb3134SXin Li for (size_t i = 0; i < n && i < v.size(); ++i) {
93*2abb3134SXin Li log("%d: %f", i, v[i]);
94*2abb3134SXin Li }
95*2abb3134SXin Li
96*2abb3134SXin Li *num_entries_out = num_entries;
97*2abb3134SXin Li *entry_size_out = entry_size;
98*2abb3134SXin Li
99*2abb3134SXin Li return true;
100*2abb3134SXin Li }
101*2abb3134SXin Li
PrintEntryVector(const vector<double> & cond_prob,size_t m,size_t entry_size)102*2abb3134SXin Li void PrintEntryVector(const vector<double>& cond_prob, size_t m,
103*2abb3134SXin Li size_t entry_size) {
104*2abb3134SXin Li size_t c_base = m * entry_size;
105*2abb3134SXin Li log("cond_prob[m = %d] = ", m);
106*2abb3134SXin Li for (size_t i = 0; i < entry_size; ++i) {
107*2abb3134SXin Li printf("%e ", cond_prob[c_base + i]);
108*2abb3134SXin Li }
109*2abb3134SXin Li printf("\n");
110*2abb3134SXin Li }
111*2abb3134SXin Li
// Dumps every element of pij to stdout on a single line, followed by the
// total of all elements on a line of its own.
void PrintPij(const vector<double>& pij) {
  printf("PIJ:\n");

  double total = 0.0;
  for (vector<double>::const_iterator it = pij.begin(); it != pij.end();
       ++it) {
    printf("%f ", *it);
    total += *it;
  }
  printf("\n");

  // The total is 1.0 once pij has been normalized.
  printf("SUM: %f\n", total);
  printf("\n");
}
123*2abb3134SXin Li
124*2abb3134SXin Li // EM algorithm to iteratively estimate parameters.
125*2abb3134SXin Li
ExpectationMaximization(uint32_t num_entries,uint32_t entry_size,const vector<double> & cond_prob,int max_em_iters,double epsilon,vector<double> * pij_out)126*2abb3134SXin Li static int ExpectationMaximization(
127*2abb3134SXin Li uint32_t num_entries, uint32_t entry_size, const vector<double>& cond_prob,
128*2abb3134SXin Li int max_em_iters, double epsilon, vector<double>* pij_out) {
129*2abb3134SXin Li // Start out with uniform distribution.
130*2abb3134SXin Li vector<double> pij(entry_size, 0.0);
131*2abb3134SXin Li double init = 1.0 / entry_size;
132*2abb3134SXin Li for (size_t i = 0; i < pij.size(); ++i) {
133*2abb3134SXin Li pij[i] = init;
134*2abb3134SXin Li }
135*2abb3134SXin Li log("Initialized %d entries with %f", pij.size(), init);
136*2abb3134SXin Li
137*2abb3134SXin Li vector<double> prev_pij(entry_size, 0.0); // pij on previous iteration
138*2abb3134SXin Li
139*2abb3134SXin Li log("Starting up to %d EM iterations", max_em_iters);
140*2abb3134SXin Li
141*2abb3134SXin Li int em_iter = 0; // visible after loop
142*2abb3134SXin Li for (; em_iter < max_em_iters; ++em_iter) {
143*2abb3134SXin Li //
144*2abb3134SXin Li // lapply() step.
145*2abb3134SXin Li //
146*2abb3134SXin Li
147*2abb3134SXin Li // Computed below as a function of old Pij and conditional probability for
148*2abb3134SXin Li // each report.
149*2abb3134SXin Li vector<double> new_pij(entry_size, 0.0);
150*2abb3134SXin Li
151*2abb3134SXin Li // m is the matrix index, giving the conditional probability matrix for a
152*2abb3134SXin Li // single report.
153*2abb3134SXin Li for (size_t m = 0; m < num_entries; ++m) {
154*2abb3134SXin Li vector<double> z(entry_size, 0.0);
155*2abb3134SXin Li
156*2abb3134SXin Li double sum_z = 0.0;
157*2abb3134SXin Li
158*2abb3134SXin Li // base index for the matrix corresponding to a report.
159*2abb3134SXin Li size_t c_base = m * entry_size;
160*2abb3134SXin Li
161*2abb3134SXin Li for (size_t i = 0; i < entry_size; ++i) { // multiply and running sum
162*2abb3134SXin Li size_t c_index = c_base + i;
163*2abb3134SXin Li z[i] = cond_prob[c_index] * pij[i];
164*2abb3134SXin Li sum_z += z[i];
165*2abb3134SXin Li }
166*2abb3134SXin Li
167*2abb3134SXin Li // Normalize and Reduce("+", wcp) step. These two steps are combined for
168*2abb3134SXin Li // memory locality.
169*2abb3134SXin Li for (size_t i = 0; i < entry_size; ++i) {
170*2abb3134SXin Li new_pij[i] += z[i] / sum_z;
171*2abb3134SXin Li }
172*2abb3134SXin Li }
173*2abb3134SXin Li
174*2abb3134SXin Li // Divide outside the loop
175*2abb3134SXin Li for (size_t i = 0; i < entry_size; ++i) {
176*2abb3134SXin Li new_pij[i] /= num_entries;
177*2abb3134SXin Li }
178*2abb3134SXin Li
179*2abb3134SXin Li //PrintPij(new_pij);
180*2abb3134SXin Li
181*2abb3134SXin Li //
182*2abb3134SXin Li // Check for termination
183*2abb3134SXin Li //
184*2abb3134SXin Li double max_dif = 0.0;
185*2abb3134SXin Li for (size_t i = 0; i < entry_size; ++i) {
186*2abb3134SXin Li double dif = std::abs(new_pij[i] - pij[i]);
187*2abb3134SXin Li if (dif > max_dif) {
188*2abb3134SXin Li max_dif = dif;
189*2abb3134SXin Li }
190*2abb3134SXin Li }
191*2abb3134SXin Li
192*2abb3134SXin Li pij = new_pij; // copy
193*2abb3134SXin Li
194*2abb3134SXin Li log("fast EM iteration %d, dif = %e", em_iter, max_dif);
195*2abb3134SXin Li
196*2abb3134SXin Li if (max_dif < epsilon) {
197*2abb3134SXin Li log("Early EM termination: %e < %e", max_dif, epsilon);
198*2abb3134SXin Li break;
199*2abb3134SXin Li }
200*2abb3134SXin Li }
201*2abb3134SXin Li
202*2abb3134SXin Li *pij_out = pij;
203*2abb3134SXin Li // If we reached iteration index 10, then there were 10 iterations: the last
204*2abb3134SXin Li // one terminated the loop.
205*2abb3134SXin Li return em_iter;
206*2abb3134SXin Li }
207*2abb3134SXin Li
// Writes a tag to f_out as four bytes: its three characters plus the
// terminating NUL. Returns true when all four bytes were written.
bool WriteTag(const char* tag, FILE* f_out) {
  const size_t kBytesOnDisk = 4;
  // The NUL terminator is written too, so the tag text itself must be
  // exactly one byte shorter than the on-disk size.
  assert(strlen(tag) + 1 == kBytesOnDisk);
  const size_t written = fwrite(tag, sizeof(char), kBytesOnDisk, f_out);
  return written == kBytesOnDisk;
}
212*2abb3134SXin Li
213*2abb3134SXin Li // Write the probabilities as a flat list of doubles. The caller knows what
214*2abb3134SXin Li // the dimensions are.
WriteResult(const vector<double> & pij,uint32_t num_em_iters,FILE * f_out)215*2abb3134SXin Li bool WriteResult(const vector<double>& pij, uint32_t num_em_iters,
216*2abb3134SXin Li FILE* f_out) {
217*2abb3134SXin Li if (!WriteTag("emi", f_out)) {
218*2abb3134SXin Li return false;
219*2abb3134SXin Li }
220*2abb3134SXin Li if (fwrite(&num_em_iters, sizeof num_em_iters, 1, f_out) != 1) {
221*2abb3134SXin Li return false;
222*2abb3134SXin Li }
223*2abb3134SXin Li
224*2abb3134SXin Li if (!WriteTag("pij", f_out)) {
225*2abb3134SXin Li return false;
226*2abb3134SXin Li }
227*2abb3134SXin Li size_t n = pij.size();
228*2abb3134SXin Li if (fwrite(&pij[0], sizeof pij[0], n, f_out) != n) {
229*2abb3134SXin Li return false;
230*2abb3134SXin Li }
231*2abb3134SXin Li return true;
232*2abb3134SXin Li }
233*2abb3134SXin Li
// Like atoi, but with basic (not exhaustive) error checking.
//
// Parses a base-10 integer from the start of s; as with strtol, leading
// whitespace is skipped and trailing characters are ignored. On success the
// value is stored in *result and true is returned. If no digits could be
// parsed, or the value does not fit in an int, *result is set to 0 and false
// is returned.
bool StringToInt(const char* s, int* result) {
  char* end;  // mutated by strtol
  *result = 0;

  errno = 0;  // strtol reports overflow of long only through errno
  const long value = strtol(s, &end, 10);  // base 10

  // If strtol didn't consume any characters, it failed.
  if (end == s) {
    return false;
  }
  // strtol yields a long: reject values it had to clamp as well as values
  // that would be silently truncated when narrowed to int.
  if (errno == ERANGE || value < INT_MIN || value > INT_MAX) {
    return false;
  }

  *result = static_cast<int>(value);
  return true;
}
246*2abb3134SXin Li
main(int argc,char ** argv)247*2abb3134SXin Li int main(int argc, char **argv) {
248*2abb3134SXin Li if (argc < 4) {
249*2abb3134SXin Li log("Usage: read_numeric INPUT OUTPUT max_em_iters");
250*2abb3134SXin Li return 1;
251*2abb3134SXin Li }
252*2abb3134SXin Li
253*2abb3134SXin Li char* in_filename = argv[1];
254*2abb3134SXin Li char* out_filename = argv[2];
255*2abb3134SXin Li
256*2abb3134SXin Li int max_em_iters;
257*2abb3134SXin Li if (!StringToInt(argv[3], &max_em_iters)) {
258*2abb3134SXin Li log("Error parsing max_em_iters");
259*2abb3134SXin Li return 1;
260*2abb3134SXin Li }
261*2abb3134SXin Li
262*2abb3134SXin Li FILE* f = fopen(in_filename, "rb");
263*2abb3134SXin Li if (f == NULL) {
264*2abb3134SXin Li return 1;
265*2abb3134SXin Li }
266*2abb3134SXin Li
267*2abb3134SXin Li // Try opening first so we don't do a long computation and then fail.
268*2abb3134SXin Li FILE* f_out = fopen(out_filename, "wb");
269*2abb3134SXin Li if (f_out == NULL) {
270*2abb3134SXin Li return 1;
271*2abb3134SXin Li }
272*2abb3134SXin Li
273*2abb3134SXin Li uint32_t num_entries;
274*2abb3134SXin Li uint32_t entry_size;
275*2abb3134SXin Li vector<double> cond_prob;
276*2abb3134SXin Li if (!ReadListOfMatrices(f, &num_entries, &entry_size, &cond_prob)) {
277*2abb3134SXin Li log("Error reading list of matrices");
278*2abb3134SXin Li return 1;
279*2abb3134SXin Li }
280*2abb3134SXin Li
281*2abb3134SXin Li fclose(f);
282*2abb3134SXin Li
283*2abb3134SXin Li // Sanity check
284*2abb3134SXin Li double debug_sum = 0.0;
285*2abb3134SXin Li for (size_t m = 0; m < num_entries; ++m) {
286*2abb3134SXin Li // base index for the matrix corresponding to a report.
287*2abb3134SXin Li size_t c_base = m * entry_size;
288*2abb3134SXin Li for (size_t i = 0; i < entry_size; ++i) { // multiply and running sum
289*2abb3134SXin Li debug_sum += cond_prob[c_base + i];
290*2abb3134SXin Li }
291*2abb3134SXin Li }
292*2abb3134SXin Li log("Debug sum: %f", debug_sum);
293*2abb3134SXin Li
294*2abb3134SXin Li double epsilon = 1e-6;
295*2abb3134SXin Li log("epsilon: %f", epsilon);
296*2abb3134SXin Li
297*2abb3134SXin Li vector<double> pij(entry_size);
298*2abb3134SXin Li int num_em_iters = ExpectationMaximization(
299*2abb3134SXin Li num_entries, entry_size, cond_prob, max_em_iters, epsilon, &pij);
300*2abb3134SXin Li
301*2abb3134SXin Li if (!WriteResult(pij, num_em_iters, f_out)) {
302*2abb3134SXin Li log("Error writing result matrix");
303*2abb3134SXin Li return 1;
304*2abb3134SXin Li }
305*2abb3134SXin Li fclose(f_out);
306*2abb3134SXin Li
307*2abb3134SXin Li log("fast EM done");
308*2abb3134SXin Li return 0;
309*2abb3134SXin Li }
310