#!/usr/bin/python
"""
fast_em.py: Tensorflow implementation of expectation maximization for RAPPOR
association analysis.

TODO:
  - Use TensorFlow ops for reading input (so that reading input can be
    distributed)
  - Reduce the number of ops (currently proportional to the number of reports).
    May require new TensorFlow ops.
  - Fix performance bug (v_split is probably being recomputed on every
    iteration):
    bin$ ./test.sh decode-assoc-cpp - 1.1 seconds (single-threaded C++)
    bin$ ./test.sh decode-assoc-tensorflow - 226 seconds on GPU
"""

import sys

import numpy as np
import tensorflow as tf


def log(msg, *args):
  if args:
    msg = msg % args
  print >>sys.stderr, msg


def ExpectTag(f, expected):
  """Read and consume a 4 byte tag from the given file."""
  b = f.read(4)
  if b != expected:
    raise RuntimeError('Expected %r, got %r' % (expected, b))


def ReadListOfMatrices(f):
  """
  Read a big list of conditional probability matrices from a binary file.
  """
  ExpectTag(f, 'ne \0')
  num_entries = np.fromfile(f, np.uint32, count=1)[0]
  log('Number of entries: %d', num_entries)

  ExpectTag(f, 'es \0')
  entry_size = np.fromfile(f, np.uint32, count=1)[0]
  log('Entry size: %d', entry_size)

  ExpectTag(f, 'dat\0')
  vec_length = num_entries * entry_size
  v = np.fromfile(f, np.float64, count=vec_length)

  log('Values read: %d', len(v))
  log('v: %s', v[:10])
  #print 'SUM', sum(v)

  # NOTE: We're not reshaping because we're using one TensorFlow tensor object
  # per matrix, since it makes the algorithm expressible with current
  # TensorFlow ops.
  #v = v.reshape((num_entries, entry_size))

  return num_entries, entry_size, v


def WriteTag(f, tag):
  if len(tag) != 3:
    raise AssertionError("Tags should be 3 bytes.  Got %r" % tag)
  f.write(tag + '\0')  # NUL terminated


def WriteResult(f, num_em_iters, pij):
  WriteTag(f, 'emi')
  emi = np.array([num_em_iters], np.uint32)
  emi.tofile(f)

  WriteTag(f, 'pij')
  pij.tofile(f)


def DebugSum(num_entries, entry_size, v):
  """Sum the entries as a sanity check."""
  cond_prob = tf.placeholder(tf.float64, shape=(num_entries * entry_size,))
  debug_sum = tf.reduce_sum(cond_prob)
  with tf.Session() as sess:
    s = sess.run(debug_sum, feed_dict={cond_prob: v})
  log('Debug sum: %f', s)


def BuildEmIter(num_entries, entry_size, v):
  # Placeholder for the value from the previous iteration.
  pij_in = tf.placeholder(tf.float64, shape=(entry_size,))

  # split along dimension 0
  # TODO:
  # - make sure this doesn't get run for every EM iteration
  # - investigate using tf.tile() instead?  (this may cost more memory)
  v_split = tf.split(0, num_entries, v)

  z_numerator = [report * pij_in for report in v_split]
  sum_z = [tf.reduce_sum(report) for report in z_numerator]
  z = [z_numerator[i] / sum_z[i] for i in xrange(num_entries)]

  # Concat per-report tensors and reshape.  This is probably inefficient?
  z_concat = tf.concat(0, z)
  z_concat = tf.reshape(z_concat, [num_entries, entry_size])

  # This whole expression represents an EM iteration.  Bind the pij_in
  # placeholder, and get a new estimation of Pij.
  em_iter_expr = tf.reduce_sum(z_concat, 0) / num_entries

  return pij_in, em_iter_expr


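# A minimal NumPy sketch of what the expression built by BuildEmIter computes,
# kept here for reference.  _NumpyEmIter is a hypothetical helper (not used by
# the pipeline); it assumes cond_prob has already been reshaped to
# (num_entries, entry_size), which the TensorFlow version deliberately avoids.
def _NumpyEmIter(cond_prob, pij):
  """One EM step: weight each report by the current Pij, normalize per report,
  then average the normalized rows to get the new Pij estimate."""
  z = cond_prob * pij                   # broadcast Pij across reports
  z = z / z.sum(axis=1, keepdims=True)  # per-report normalization (sum_z)
  return z.mean(axis=0)                 # average over reports -> new Pij

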
def RunEm(pij_in, entry_size, em_iter_expr, max_em_iters, epsilon=1e-6):
  """Run the iterative EM algorithm (using the TensorFlow API).

  Args:
    pij_in: placeholder for the previous iteration's Pij estimate
    entry_size: total number of cells in each matrix
    em_iter_expr: TensorFlow expression for one EM iteration (from BuildEmIter)
    max_em_iters: maximum number of EM iterations
    epsilon: convergence threshold for early termination

  Returns:
    A (num_em_iters, pij) tuple, where pij is a numpy.ndarray (e.g. vector of
    length 8).
  """
  # Initial value is the uniform distribution
  pij = np.ones(entry_size) / entry_size

  i = 0  # visible outside loop

  # Do EM iterations.
  with tf.Session() as sess:
    for i in xrange(max_em_iters):
      print 'PIJ', pij
      new_pij = sess.run(em_iter_expr, feed_dict={pij_in: pij})
      dif = max(abs(new_pij - pij))
      log('EM iteration %d, dif = %e', i, dif)
      pij = new_pij

      if dif < epsilon:
        log('Early EM termination: %e < %e', dif, epsilon)
        break

  # If i = 9, then we did 10 iterations.
  return i + 1, pij


def sep():
  print '-' * 80


def main(argv):
  input_path = argv[1]
  output_path = argv[2]
  max_em_iters = int(argv[3])

  sep()
  with open(input_path) as f:
    num_entries, entry_size, cond_prob = ReadListOfMatrices(f)

  sep()
  DebugSum(num_entries, entry_size, cond_prob)

  sep()
  pij_in, em_iter_expr = BuildEmIter(num_entries, entry_size, cond_prob)
  num_em_iters, pij = RunEm(pij_in, entry_size, em_iter_expr, max_em_iters)

  sep()
  log('Final Pij: %s', pij)

  with open(output_path, 'wb') as f:
    WriteResult(f, num_em_iters, pij)
  log('Wrote %s', output_path)


if __name__ == '__main__':
  try:
    main(sys.argv)
  except RuntimeError, e:
    print >>sys.stderr, 'FATAL: %s' % e
    sys.exit(1)
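
# Example invocation (hypothetical paths; the input is the binary matrix dump
# consumed by ReadListOfMatrices, the third argument is the maximum number of
# EM iterations):
#
#   ./fast_em.py /tmp/assoc_input.bin /tmp/assoc_output.bin 1000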