xref: /aosp_15_r20/external/rappor/analysis/tensorflow/fast_em.py (revision 2abb31345f6c95944768b5222a9a5ed3fc68cc00)
1#!/usr/bin/python
2"""
3fast_em.py: Tensorflow implementation of expectation maximization for RAPPOR
4association analysis.
5
6TODO:
7  - Use TensorFlow ops for reading input (so that reading input can be
8    distributed)
9  - Reduce the number of ops (currently proportional to the number of reports).
10    May require new TensorFlow ops.
11  - Fix performance bug (v_split is probably being recomputed on every
12    iteration):
13    bin$ ./test.sh decode-assoc-cpp - 1.1 seconds (single-threaded C++)
14    bin$ ./test.sh decode-assoc-tensorflow - 226 seconds on GPU
15"""
16
17import sys
18
19import numpy as np
20import tensorflow as tf
21
22
def log(msg, *args):
  """Write a printf-style message to stderr, followed by a newline."""
  if args:
    msg = msg % args
  sys.stderr.write(msg + '\n')
27
28
def ExpectTag(f, expected):
  """Consume a 4-byte tag from the file and verify it matches `expected`.

  Raises:
    RuntimeError: if the bytes read differ from the expected tag.
  """
  actual = f.read(4)
  if actual != expected:
    raise RuntimeError('Expected %r, got %r' % (expected, actual))
34
35
def ReadListOfMatrices(f):
  """Load a flat vector of per-report conditional probability matrices.

  The binary format is three tagged sections:
    'ne \0' + uint32 entry count,
    'es \0' + uint32 entry size,
    'dat\0' + float64 data (num_entries * entry_size values).

  Returns:
    (num_entries, entry_size, values) where values is a flat float64
    vector of length num_entries * entry_size.
  """
  ExpectTag(f, 'ne \0')
  num_entries = np.fromfile(f, np.uint32, count=1)[0]
  log('Number of entries: %d', num_entries)

  ExpectTag(f, 'es \0')
  entry_size = np.fromfile(f, np.uint32, count=1)[0]
  log('Entry size: %d', entry_size)

  ExpectTag(f, 'dat\0')
  values = np.fromfile(f, np.float64, count=num_entries * entry_size)
  log('Values read: %d', len(values))
  log('v: %s', values[:10])

  # Deliberately left flat (not reshaped to num_entries x entry_size): the
  # algorithm uses one TensorFlow tensor per matrix, which keeps it
  # expressible with current TensorFlow ops.
  return num_entries, entry_size, values
62
63
def WriteTag(f, tag):
  """Write a 3-character tag to the file, NUL-terminated (4 bytes total).

  Raises:
    AssertionError: if the tag is not exactly 3 characters long.
  """
  if len(tag) == 3:
    f.write(tag + '\0')  # NUL terminated
  else:
    raise AssertionError("Tags should be 3 bytes.  Got %r" % tag)
68
69
def WriteResult(f, num_em_iters, pij):
  """Serialize the EM result: the iteration count, then the Pij vector.

  Each value is preceded by a 4-byte tag ('emi' / 'pij', NUL-terminated).
  """
  WriteTag(f, 'emi')
  np.array([num_em_iters], np.uint32).tofile(f)

  WriteTag(f, 'pij')
  pij.tofile(f)
77
78
def DebugSum(num_entries, entry_size, v):
  """Sum all input values through TensorFlow and log it, as a sanity check."""
  flat_shape = (num_entries * entry_size,)
  cond_prob = tf.placeholder(tf.float64, shape=flat_shape)
  total = tf.reduce_sum(cond_prob)
  with tf.Session() as sess:
    result = sess.run(total, feed_dict={cond_prob: v})
  log('Debug sum: %f', result)
86
87
def BuildEmIter(num_entries, entry_size, v):
  """Construct the TensorFlow graph for a single EM iteration.

  Args:
    num_entries: number of reports (one conditional probability matrix each)
    entry_size: number of cells per matrix
    v: flat float64 vector of length num_entries * entry_size

  Returns:
    (pij_in, em_iter_expr): the placeholder to feed the previous Pij
    estimate into, and the expression computing the next estimate.
  """
  # Placeholder for the estimate from the previous iteration.
  pij_in = tf.placeholder(tf.float64, shape=(entry_size,))

  # One tensor per report: split along dimension 0.
  # TODO:
  # - make sure this doesn't get run for every EM iteration
  # - investigate using tf.tile() instead?  (this may cost more memory)
  per_report = tf.split(0, num_entries, v)

  # Weight each report by the current estimate, then normalize each
  # report's weights to sum to 1.
  weighted = [report * pij_in for report in per_report]
  z = [w / tf.reduce_sum(w) for w in weighted]

  # Concat the per-report tensors back into a num_entries x entry_size
  # matrix.  This is probably inefficient?
  stacked = tf.reshape(tf.concat(0, z), [num_entries, entry_size])

  # Averaging the per-report distributions yields the new estimate; this
  # whole expression is one EM iteration.  Bind pij_in to evaluate it.
  em_iter_expr = tf.reduce_sum(stacked, 0) / num_entries

  return pij_in, em_iter_expr
111
112
def RunEm(pij_in, entry_size, em_iter_expr, max_em_iters, epsilon=1e-6):
  """Run the iterative EM algorithm (using the TensorFlow API).

  Args:
    pij_in: placeholder tensor to feed the previous Pij estimate into
    entry_size: total number of cells in each matrix
    em_iter_expr: TensorFlow expression computing one EM iteration
    max_em_iters: maximum number of EM iterations
    epsilon: convergence threshold; iteration stops early once the largest
      elementwise change between successive estimates falls below it

  Returns:
    (num_iters, pij): number of iterations actually run, and the final
    estimate pij, a numpy.ndarray (e.g. vector of length 8).
  """
  # Initial value is the uniform distribution
  pij = np.ones(entry_size) / entry_size

  i = 0  # visible outside loop

  # Do EM iterations.
  with tf.Session() as sess:
    for i in xrange(max_em_iters):
      # Debug output routed through log() (stderr) like the rest of this
      # module, instead of a bare Python 2 print to stdout.
      log('PIJ %s', pij)
      new_pij = sess.run(em_iter_expr, feed_dict={pij_in: pij})
      dif = max(abs(new_pij - pij))
      log('EM iteration %d, dif = %e', i, dif)
      pij = new_pij

      if dif < epsilon:
        # BUG FIX: this previously referenced 'max_dif', an undefined name,
        # so reaching convergence raised NameError instead of terminating.
        log('Early EM termination: %e < %e', dif, epsilon)
        break

  # If i = 9, then we did 10 iterations.
  return i + 1, pij
145
146
def sep():
  """Print an 80-character separator line to stdout."""
  # Parenthesized form behaves identically under Python 2's print statement.
  print('-' * 80)
149
150
def main(argv):
  """Entry point: load matrices, run EM, and write the result.

  Usage: fast_em.py <input path> <output path> <max EM iterations>
  """
  input_path, output_path = argv[1], argv[2]
  max_em_iters = int(argv[3])

  # Load the per-report conditional probability matrices.
  sep()
  with open(input_path) as f:
    num_entries, entry_size, cond_prob = ReadListOfMatrices(f)

  # Sanity-check the data via a TensorFlow reduction.
  sep()
  DebugSum(num_entries, entry_size, cond_prob)

  # Build the iteration graph once, then run EM to convergence (or the
  # iteration cap).
  sep()
  pij_in, em_iter_expr = BuildEmIter(num_entries, entry_size, cond_prob)
  num_em_iters, pij = RunEm(pij_in, entry_size, em_iter_expr, max_em_iters)

  sep()
  log('Final Pij: %s', pij)

  with open(output_path, 'wb') as f:
    WriteResult(f, num_em_iters, pij)
  log('Wrote %s', output_path)
173
174
if __name__ == '__main__':
  try:
    main(sys.argv)
  except RuntimeError as e:
    # Report fatal errors on stderr and exit nonzero.
    sys.stderr.write('FATAL: %s\n' % e)
    sys.exit(1)
181