# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for plan_utils."""

import functools
import tempfile
from typing import Any, Optional

from absl.testing import absltest
import tensorflow as tf

from fcp.demo import plan_utils
from fcp.demo import test_utils
from fcp.protos import plan_pb2
from fcp.tensorflow import serve_slices

DEFAULT_INITIAL_CHECKPOINT = b'initial'
CHECKPOINT_TENSOR_NAME = 'checkpoint'
INTERMEDIATE_TENSOR_NAME = 'intermediate_value'
FINAL_TENSOR_NAME = 'final_value'
NUM_SLICES = 3


def create_plan(log_file: Optional[str] = None) -> plan_pb2.Plan:
  """Builds the single-phase test Plan used throughout this module.

  The server graph adds the restored intermediate update into a final
  accumulator and exposes checkpoint ops for writing the client checkpoint,
  reading the intermediate update, and saving/restoring the server state.

  Args:
    log_file: Optional path of a file. When given, the plan's ops print their
      names to it as they run; when omitted, those logging ops are NoOps.

  Returns:
    A Plan with the client graph, server graph and TensorFlow config packed in.
  """

  def log_op(name: str) -> tf.Operation:
    """Returns an op that prints `name` to `log_file` (a NoOp without one)."""
    if not log_file:
      return tf.raw_ops.NoOp()
    return tf.print(name, output_stream=f'file://{log_file}')

  def create_checkpoint_op(
      name: str,
      filename_op: Any,
      save_op: Any = None,
      restore_op: Any = None,
      session_token_tensor_name: Optional[str] = None,
  ) -> plan_pb2.CheckpointOp:
    """Wraps optional save/restore ops in a CheckpointOp with logging hooks.

    Args:
      name: Prefix used for every name logged by this CheckpointOp.
      filename_op: Graph element whose name becomes the filename tensor name.
      save_op: Optional op that must run before the save is logged.
      restore_op: Optional op that must run before the restore is logged.
      session_token_tensor_name: Optional name forwarded to the CheckpointOp.

    Returns:
      The assembled CheckpointOp.
    """
    # The hook ops are created first, in this exact order.
    hook_op_names = {
        'before_restore_op': log_op(f'{name}/before_restore').name,
        'after_restore_op': log_op(f'{name}/after_restore').name,
        'before_save_op': log_op(f'{name}/before_save').name,
        'after_save_op': log_op(f'{name}/after_save').name,
    }

    # The saver references the logging ops; the control dependencies make the
    # real save/restore (or a NoOp placeholder) run before each one logs.
    if save_op is None:
      save_op = tf.raw_ops.NoOp()
    with tf.control_dependencies([save_op]):
      logged_save = log_op(f'{name}/save')

    if restore_op is None:
      restore_op = tf.raw_ops.NoOp()
    with tf.control_dependencies([restore_op]):
      logged_restore = log_op(f'{name}/restore')

    saver_def = tf.compat.v1.train.SaverDef(
        filename_tensor_name=filename_op.name,
        restore_op_name=logged_restore.name,
        save_tensor_name=logged_save.name,
        version=tf.compat.v1.train.SaverDef.V1,
    )
    return plan_pb2.CheckpointOp(
        saver_def=saver_def,
        session_token_tensor_name=session_token_tensor_name,
        **hook_op_names,
    )

  # The client graph only needs some content; a single constant suffices.
  with tf.compat.v1.Graph().as_default() as client_graph:
    tf.constant(0)

  with tf.compat.v1.Graph().as_default() as server_graph:
    # Initialization:
    last_client_update = tf.Variable(0, dtype=tf.int32)
    intermediate_acc = tf.Variable(0, dtype=tf.int32)
    last_intermediate_update = tf.Variable(0, dtype=tf.int32)
    final_acc = tf.Variable(0, dtype=tf.int32)
    state_variables = (
        last_client_update,
        intermediate_acc,
        last_intermediate_update,
        final_acc,
    )
    with tf.control_dependencies([v.initializer for v in state_variables]):
      phase_init_op = log_op('phase_init')

    # Ops for Federated Select:
    slices_init_op = log_op('slices/initialize')
    server_vals = [
        tf.constant(1234),
        tf.constant('asdf'),
        tf.constant([1, 2, 3]),
    ]
    server_val_inputs = [
        tf.compat.v1.placeholder(val.dtype) for val in server_vals
    ]
    key_input = tf.compat.v1.placeholder(tf.int32, shape=())
    slice_filename_input = tf.compat.v1.placeholder(tf.string, shape=())
    # Each value fed to the select function must equal the matching constant.
    input_checks = [
        tf.debugging.assert_equal(fed_value, expected_value)
        for fed_value, expected_value in zip(server_val_inputs, server_vals)
    ]
    with tf.control_dependencies([log_op('slices/save_slice'), *input_checks]):
      # Every slice file simply contains its own key rendered as a string.
      save_slice_op = tf.io.write_file(
          slice_filename_input, tf.strings.as_string(key_input)
      )

    # Some tests disable passing the callback token; set `served_at_id` to '-'
    # in that case.
    callback_token = tf.compat.v1.placeholder_with_default('', shape=())
    serve_slices_kwargs = dict(
        callback_token=callback_token,
        server_val=server_vals,
        max_key=NUM_SLICES - 1,
        select_fn_initialize_op=slices_init_op.name,
        select_fn_server_val_input_tensor_names=[
            placeholder.name for placeholder in server_val_inputs
        ],
        select_fn_key_input_tensor_name=key_input.name,
        select_fn_filename_input_tensor_name=slice_filename_input.name,
        select_fn_target_tensor_name=save_slice_op.name,
    )
    served_at_id = tf.cond(
        tf.equal(callback_token, ''),
        lambda: '-',
        functools.partial(serve_slices.serve_slices, **serve_slices_kwargs),
    )

    # Ops for L2 Aggregation:
    client_checkpoint_data = tf.Variable(
        DEFAULT_INITIAL_CHECKPOINT, dtype=tf.string)

    write_client_init_filename = tf.compat.v1.placeholder(tf.string, shape=())
    # Use the variable's initial value while it has not been initialized.
    client_checkpoint_data_value = tf.cond(
        tf.compat.v1.is_variable_initialized(client_checkpoint_data),
        client_checkpoint_data.read_value,
        lambda: client_checkpoint_data.initial_value,
    )
    # The client checkpoint is "<checkpoint data> <served_at_id>".
    client_init_contents = tf.strings.join(
        [client_checkpoint_data_value, served_at_id], separator=' '
    )
    write_client_init_file = tf.io.write_file(
        write_client_init_filename, client_init_contents
    )
    write_client_init_op = create_checkpoint_op(
        'write_client_init',
        write_client_init_filename,
        save_op=write_client_init_file,
        session_token_tensor_name=callback_token.name,
    )

    read_intermediate_update_filename = tf.compat.v1.placeholder(
        tf.string, shape=())
    restored_intermediate = tf.raw_ops.Restore(
        file_pattern=read_intermediate_update_filename,
        tensor_name=INTERMEDIATE_TENSOR_NAME,
        dt=tf.int32)
    read_intermediate_update_op = create_checkpoint_op(
        'read_intermediate_update',
        read_intermediate_update_filename,
        restore_op=last_intermediate_update.assign(restored_intermediate))

    with tf.control_dependencies([log_op('apply_aggregated_updates')]):
      apply_aggregated_updates_op = final_acc.assign_add(
          last_intermediate_update)

    server_savepoint_filename = tf.compat.v1.placeholder(tf.string, shape=())
    save_final_acc = tf.raw_ops.Save(
        filename=server_savepoint_filename,
        tensor_names=[FINAL_TENSOR_NAME],
        data=[final_acc])
    restored_checkpoint_data = tf.raw_ops.Restore(
        file_pattern=server_savepoint_filename,
        tensor_name=CHECKPOINT_TENSOR_NAME,
        dt=tf.string)
    server_savepoint_op = create_checkpoint_op(
        'server_savepoint',
        server_savepoint_filename,
        save_op=save_final_acc,
        restore_op=client_checkpoint_data.assign(restored_checkpoint_data))

  config_proto = tf.compat.v1.ConfigProto(operation_timeout_in_ms=1234)

  server_phase = plan_pb2.ServerPhase(
      phase_init_op=phase_init_op.name,
      write_client_init=write_client_init_op,
      read_intermediate_update=read_intermediate_update_op,
      apply_aggregrated_updates_op=apply_aggregated_updates_op.name)
  phase = plan_pb2.Plan.Phase(
      client_phase=plan_pb2.ClientPhase(name='ClientPhase'),
      server_phase=server_phase)
  plan = plan_pb2.Plan(
      phase=[phase],
      server_savepoint=server_savepoint_op,
      client_tflite_graph_bytes=b'tflite-graph',
      version=1)
  plan.client_graph_bytes.Pack(client_graph.as_graph_def())
  plan.server_graph_bytes.Pack(server_graph.as_graph_def())
  plan.tensorflow_config_proto.Pack(config_proto)
  return plan


def create_checkpoint(tensor_name=b'test'):
  """Creates a test initial checkpoint.

  Args:
    tensor_name: Despite its name, this is the value stored in the checkpoint
      under CHECKPOINT_TENSOR_NAME.

  Returns:
    Whatever `test_utils.create_checkpoint` produces for that single entry.
  """
  contents = {CHECKPOINT_TENSOR_NAME: tensor_name}
  return test_utils.create_checkpoint(contents)


class PlanUtilsTest(absltest.TestCase):
  """Exercises plan_utils.Session against the Plan built by create_plan."""

  def test_session_enter_exit(self):
    """A default TF session exists only inside the Session context."""
    get_default_session = tf.compat.v1.get_default_session
    self.assertIsNone(get_default_session())
    with plan_utils.Session(create_plan(), create_checkpoint()):
      self.assertIsNotNone(get_default_session())
    self.assertIsNone(get_default_session())

  def test_session_without_phase(self):
    """A Plan with no phases is rejected with ValueError."""
    phaseless_plan = create_plan()
    phaseless_plan.ClearField('phase')
    with self.assertRaises(ValueError):
      plan_utils.Session(phaseless_plan, create_checkpoint())

  def test_session_without_server_phase(self):
    """A phase lacking its server_phase is rejected with ValueError."""
    plan = create_plan()
    only_phase = plan.phase[0]
    only_phase.ClearField('server_phase')
    with self.assertRaises(ValueError):
      plan_utils.Session(plan, create_checkpoint())

  def test_session_with_multiple_phases(self):
    """A Plan containing two phases is rejected with ValueError."""
    plan = create_plan()
    plan.phase.append(plan.phase[0])
    with self.assertRaises(ValueError):
      plan_utils.Session(plan, create_checkpoint())

  def test_session_client_plan(self):
    """client_plan serializes a ClientOnlyPlan built from the Plan's fields."""
    plan = create_plan()
    with plan_utils.Session(plan, create_checkpoint()) as session:
      actual = plan_pb2.ClientOnlyPlan.FromString(session.client_plan)
      expected = plan_pb2.ClientOnlyPlan(
          phase=plan.phase[0].client_phase,
          graph=plan.client_graph_bytes.value,
          tflite_graph=plan.client_tflite_graph_bytes,
          tensorflow_config_proto=plan.tensorflow_config_proto)
      self.assertEqual(actual, expected)

  def test_session_client_plan_without_tensorflow_config(self):
    """client_plan has no TensorFlow config when the Plan has none."""
    plan = create_plan()
    plan.ClearField('tensorflow_config_proto')
    with plan_utils.Session(plan, create_checkpoint()) as session:
      actual = plan_pb2.ClientOnlyPlan.FromString(session.client_plan)
      expected = plan_pb2.ClientOnlyPlan(
          phase=plan.phase[0].client_phase,
          graph=plan.client_graph_bytes.value,
          tflite_graph=plan.client_tflite_graph_bytes)
      self.assertEqual(actual, expected)

  def test_session_client_plan_without_tflite_graph(self):
    """client_plan has no TFLite graph when the Plan has none."""
    plan = create_plan()
    plan.ClearField('client_tflite_graph_bytes')
    with plan_utils.Session(plan, create_checkpoint()) as session:
      actual = plan_pb2.ClientOnlyPlan.FromString(session.client_plan)
      expected = plan_pb2.ClientOnlyPlan(
          phase=plan.phase[0].client_phase,
          graph=plan.client_graph_bytes.value,
          tensorflow_config_proto=plan.tensorflow_config_proto)
      self.assertEqual(actual, expected)

  def test_session_client_checkpoint(self):
    """client_checkpoint is the restored data joined with the served-at id."""
    expected = b'test-client-checkpoint'
    plan = create_plan()
    # create_checkpoint stores its argument under CHECKPOINT_TENSOR_NAME.
    initial_checkpoint = create_checkpoint(expected)
    with plan_utils.Session(plan, initial_checkpoint) as session:
      actual = session.client_checkpoint
      served_at_id = next(iter(session.slices))
      self.assertEqual(
          actual,
          b' '.join([expected, served_at_id.encode()]),
      )

  def test_session_client_checkpoint_without_server_savepoint(self):
    """Without a server_savepoint the default checkpoint data is used."""
    plan = create_plan()
    # If server_savepoint isn't set, the checkpoint shouldn't be loaded.
    plan.ClearField('server_savepoint')
    with plan_utils.Session(plan, create_checkpoint()) as session:
      expected_prefix = DEFAULT_INITIAL_CHECKPOINT + b' '
      self.assertStartsWith(session.client_checkpoint, expected_prefix)

  def test_session_finalize(self):
    """finalize runs the plan's ops in order and saves the aggregate."""
    startup_log = [
        'server_savepoint/before_restore',
        'server_savepoint/restore',
        'server_savepoint/after_restore',
        'phase_init',
        'write_client_init/before_save',
        'write_client_init/save',
        'write_client_init/after_save',
    ]
    per_slice_log = ['slices/initialize', 'slices/save_slice']
    finalize_log = [
        'read_intermediate_update/before_restore',
        'read_intermediate_update/restore',
        'read_intermediate_update/after_restore',
        'apply_aggregated_updates',
        'server_savepoint/before_save',
        'server_savepoint/save',
        'server_savepoint/after_save',
    ]

    with tempfile.NamedTemporaryFile('r') as tmpfile:
      with plan_utils.Session(create_plan(tmpfile.name),
                              create_checkpoint()) as session:
        intermediate_update = test_utils.create_checkpoint(
            {INTERMEDIATE_TENSOR_NAME: 3})
        checkpoint = session.finalize(intermediate_update)
      # The log has to be read while the temporary file is still open.
      logged_lines = tmpfile.read().splitlines()
      self.assertSequenceEqual(
          logged_lines,
          startup_log + per_slice_log * NUM_SLICES + finalize_log,
      )

    result = test_utils.read_tensor_from_checkpoint(checkpoint,
                                                    FINAL_TENSOR_NAME, tf.int32)
    # The value should be propagated from the intermediate aggregate.
    self.assertEqual(result, 3)

  def test_session_with_tensorflow_error(self):
    """A phase_init_op name missing from the graph raises ValueError."""
    plan = create_plan()
    server_phase = plan.phase[0].server_phase
    server_phase.phase_init_op = 'does-not-exist'
    with self.assertRaises(ValueError):
      plan_utils.Session(plan, create_checkpoint())

  def test_session_slices(self):
    """slices maps the single served-at id to one entry per slice key."""
    with plan_utils.Session(create_plan(), create_checkpoint()) as session:
      # The served_at_id should match the value in the client checkpoint.
      checkpoint_parts = session.client_checkpoint.split(b' ')
      served_at_id = checkpoint_parts[1].decode()
      self.assertSameElements(session.slices.keys(), [served_at_id])
      expected_slices = [str(key).encode() for key in range(NUM_SLICES)]
      self.assertListEqual(session.slices[served_at_id], expected_slices)

  def test_session_without_slices(self):
    """No slices are served when the session token tensor name is unset."""
    plan = create_plan()
    write_client_init = plan.phase[0].server_phase.write_client_init
    write_client_init.ClearField('session_token_tensor_name')
    with plan_utils.Session(plan, create_checkpoint()) as session:
      self.assertEmpty(session.slices)


if __name__ == '__main__':
  absltest.main()