xref: /aosp_15_r20/external/federated-compute/fcp/demo/plan_utils_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for plan_utils."""
15
16import functools
17import tempfile
18from typing import Any, Optional
19
20from absl.testing import absltest
21import tensorflow as tf
22
23from fcp.demo import plan_utils
24from fcp.demo import test_utils
25from fcp.protos import plan_pb2
26from fcp.tensorflow import serve_slices
27
# Initial value of the server graph's client-checkpoint variable in
# create_plan(); tests expect it when no server savepoint is restored.
DEFAULT_INITIAL_CHECKPOINT = b'initial'
# Tensor name restored by the server savepoint's restore op and written by
# create_checkpoint().
CHECKPOINT_TENSOR_NAME = 'checkpoint'
# Tensor name restored from the intermediate update passed to finalize().
INTERMEDIATE_TENSOR_NAME = 'intermediate_value'
# Tensor name under which the final accumulator is saved by the server
# savepoint's save op.
FINAL_TENSOR_NAME = 'final_value'
# Number of slices served by the test plan (serve_slices gets
# max_key=NUM_SLICES - 1).
NUM_SLICES = 3
33
34
def create_plan(log_file: Optional[str] = None) -> plan_pb2.Plan:
  """Creates a test Plan that sums inputs.

  The plan has a trivial client graph (a single constant) and a server graph
  containing the ops for phase initialization, Federated Select slice serving,
  writing the client checkpoint, reading an intermediate update, applying it to
  a final accumulator, and saving/restoring a server savepoint.

  Args:
    log_file: Optional path of a file. When set, every op created through
      `log_op` prints its name to this file when it runs, which lets tests
      check the order in which the plan's ops are invoked. When unset, those
      ops are no-ops.

  Returns:
    A `plan_pb2.Plan` with a single phase, a server savepoint, placeholder
    TFLite graph bytes, version 1, and the client graph, server graph and
    TensorFlow config proto packed into their respective fields.
  """

  def log_op(name: str) -> tf.Operation:
    """Returns an op that prints `name` to `log_file`, or a no-op if unset."""
    if log_file:
      return tf.print(name, output_stream=f'file://{log_file}')
    return tf.raw_ops.NoOp()

  def create_checkpoint_op(
      name: str,
      filename_op: Any,
      save_op: Any = None,
      restore_op: Any = None,
      session_token_tensor_name: Optional[str] = None,
  ) -> plan_pb2.CheckpointOp:
    """Builds a CheckpointOp whose save/restore and hook ops are all logged.

    Args:
      name: Prefix used for the names printed by the logging ops
        (`<name>/before_restore`, `<name>/restore`, ...).
      filename_op: Graph element whose `.name` becomes the SaverDef's
        filename tensor name.
      save_op: Optional op performing the save; a no-op is used if None.
      restore_op: Optional op performing the restore; a no-op is used if None.
      session_token_tensor_name: Optional value for the CheckpointOp field of
        the same name.

    Returns:
      A `plan_pb2.CheckpointOp` referencing the created ops by name.
    """
    before_restore = log_op(f'{name}/before_restore')
    after_restore = log_op(f'{name}/after_restore')
    before_save = log_op(f'{name}/before_save')
    after_save = log_op(f'{name}/after_save')
    # Wrap the real save/restore ops: the ops named in the SaverDef are the
    # logging ops, which carry a control dependency on the real op (or on a
    # no-op when none was supplied).
    with tf.control_dependencies(
        [save_op if save_op is not None else tf.raw_ops.NoOp()]):
      save_op = log_op(f'{name}/save')
    with tf.control_dependencies(
        [restore_op if restore_op is not None else tf.raw_ops.NoOp()]):
      restore_op = log_op(f'{name}/restore')
    return plan_pb2.CheckpointOp(
        saver_def=tf.compat.v1.train.SaverDef(
            filename_tensor_name=filename_op.name,
            restore_op_name=restore_op.name,
            save_tensor_name=save_op.name,
            version=tf.compat.v1.train.SaverDef.V1,
        ),
        before_restore_op=before_restore.name,
        after_restore_op=after_restore.name,
        before_save_op=before_save.name,
        after_save_op=after_save.name,
        session_token_tensor_name=session_token_tensor_name,
    )

  # The client graph only needs to be non-empty; its contents are never run
  # in this file.
  with tf.compat.v1.Graph().as_default() as client_graph:
    tf.constant(0)

  with tf.compat.v1.Graph().as_default() as server_graph:
    # Initialization:
    # Four int32 variables starting at 0. Within this function only
    # `last_intermediate_update` and `final_acc` are used beyond their
    # initializers.
    last_client_update = tf.Variable(0, dtype=tf.int32)
    intermediate_acc = tf.Variable(0, dtype=tf.int32)
    last_intermediate_update = tf.Variable(0, dtype=tf.int32)
    final_acc = tf.Variable(0, dtype=tf.int32)
    # phase_init_op logs 'phase_init' and depends on all variable
    # initializers.
    with tf.control_dependencies([
        last_client_update.initializer, intermediate_acc.initializer,
        last_intermediate_update.initializer, final_acc.initializer
    ]):
      phase_init_op = log_op('phase_init')

    # Ops for Federated Select:
    select_fn_initialize_op = log_op('slices/initialize')
    # Server values handed to serve_slices; a mix of dtypes and shapes.
    select_fn_server_vals = [
        tf.constant(1234),
        tf.constant('asdf'),
        tf.constant([1, 2, 3]),
    ]
    # One placeholder per server value, with matching dtype.
    select_fn_server_val_inputs = [
        tf.compat.v1.placeholder(v.dtype) for v in select_fn_server_vals
    ]
    select_fn_key_input = tf.compat.v1.placeholder(tf.int32, shape=())
    select_fn_filename_input = tf.compat.v1.placeholder(tf.string, shape=())
    # Check that the values fed into the placeholders equal the original
    # server values.
    assertions = [
        tf.debugging.assert_equal(placeholder, constant)
        for placeholder, constant in zip(
            select_fn_server_val_inputs, select_fn_server_vals
        )
    ]
    # Saving a slice logs 'slices/save_slice', runs the assertions, and then
    # writes the slice key (as a decimal string) to the given filename.
    with tf.control_dependencies([log_op('slices/save_slice')] + assertions):
      select_fn_save_op = tf.io.write_file(
          select_fn_filename_input, tf.strings.as_string(select_fn_key_input)
      )
    # Some tests disable passing the callback token; set `served_at_id` to '-'
    # in that case.
    callback_token = tf.compat.v1.placeholder_with_default('', shape=())
    served_at_id = tf.cond(
        tf.equal(callback_token, ''),
        lambda: '-',
        functools.partial(
            serve_slices.serve_slices,
            callback_token=callback_token,
            server_val=select_fn_server_vals,
            max_key=NUM_SLICES - 1,
            select_fn_initialize_op=select_fn_initialize_op.name,
            select_fn_server_val_input_tensor_names=[
                v.name for v in select_fn_server_val_inputs
            ],
            select_fn_key_input_tensor_name=select_fn_key_input.name,
            select_fn_filename_input_tensor_name=select_fn_filename_input.name,
            select_fn_target_tensor_name=select_fn_save_op.name,
        ),
    )

    # Ops for L2 Aggregation:
    client_checkpoint_data = tf.Variable(
        DEFAULT_INITIAL_CHECKPOINT, dtype=tf.string)

    write_client_init_filename = tf.compat.v1.placeholder(tf.string, shape=())
    # Read the variable if it has been initialized (e.g. assigned by the
    # server savepoint restore op below); otherwise fall back to its initial
    # value.
    client_checkpoint_data_value = tf.cond(
        tf.compat.v1.is_variable_initialized(client_checkpoint_data),
        client_checkpoint_data.read_value,
        lambda: client_checkpoint_data.initial_value,
    )
    # The client "checkpoint" written here is the string
    # '<checkpoint data> <served_at_id>'.
    write_client_init_op = create_checkpoint_op(
        'write_client_init',
        write_client_init_filename,
        save_op=tf.io.write_file(
            write_client_init_filename,
            tf.strings.join(
                [client_checkpoint_data_value, served_at_id], separator=' '
            ),
        ),
        session_token_tensor_name=callback_token.name,
    )

    read_intermediate_update_filename = tf.compat.v1.placeholder(
        tf.string, shape=())
    # Restoring loads INTERMEDIATE_TENSOR_NAME from the given file into
    # `last_intermediate_update`.
    read_intermediate_update_op = create_checkpoint_op(
        'read_intermediate_update',
        read_intermediate_update_filename,
        restore_op=last_intermediate_update.assign(
            tf.raw_ops.Restore(
                file_pattern=read_intermediate_update_filename,
                tensor_name=INTERMEDIATE_TENSOR_NAME,
                dt=tf.int32)))

    # Applying the aggregated update adds the last intermediate update to the
    # final accumulator.
    with tf.control_dependencies([log_op('apply_aggregated_updates')]):
      apply_aggregated_updates_op = final_acc.assign_add(
          last_intermediate_update)

    server_savepoint_filename = tf.compat.v1.placeholder(tf.string, shape=())
    # Saving writes `final_acc` under FINAL_TENSOR_NAME; restoring loads
    # CHECKPOINT_TENSOR_NAME into `client_checkpoint_data`.
    server_savepoint_op = create_checkpoint_op(
        'server_savepoint',
        server_savepoint_filename,
        save_op=tf.raw_ops.Save(
            filename=server_savepoint_filename,
            tensor_names=[FINAL_TENSOR_NAME],
            data=[final_acc]),
        restore_op=client_checkpoint_data.assign(
            tf.raw_ops.Restore(
                file_pattern=server_savepoint_filename,
                tensor_name=CHECKPOINT_TENSOR_NAME,
                dt=tf.string)))

  # Arbitrary, easily recognizable config value so tests can check that the
  # config proto is propagated.
  config_proto = tf.compat.v1.ConfigProto(operation_timeout_in_ms=1234)

  # NOTE(review): `apply_aggregrated_updates_op` (sic) is the keyword used for
  # the ServerPhase field here; presumably it matches the spelling in the
  # proto definition -- confirm before "fixing" it.
  plan = plan_pb2.Plan(
      phase=[
          plan_pb2.Plan.Phase(
              client_phase=plan_pb2.ClientPhase(name='ClientPhase'),
              server_phase=plan_pb2.ServerPhase(
                  phase_init_op=phase_init_op.name,
                  write_client_init=write_client_init_op,
                  read_intermediate_update=read_intermediate_update_op,
                  apply_aggregrated_updates_op=(
                      apply_aggregated_updates_op.name)))
      ],
      server_savepoint=server_savepoint_op,
      client_tflite_graph_bytes=b'tflite-graph',
      version=1)
  plan.client_graph_bytes.Pack(client_graph.as_graph_def())
  plan.server_graph_bytes.Pack(server_graph.as_graph_def())
  plan.tensorflow_config_proto.Pack(config_proto)
  return plan
204
205
def create_checkpoint(tensor_name=b'test'):
  """Builds a test initial checkpoint via test_utils.create_checkpoint.

  Args:
    tensor_name: Despite its name, this is used as the entry stored under
      CHECKPOINT_TENSOR_NAME in the mapping handed to
      test_utils.create_checkpoint.

  Returns:
    Whatever test_utils.create_checkpoint returns for that mapping.
  """
  checkpoint_contents = {CHECKPOINT_TENSOR_NAME: tensor_name}
  return test_utils.create_checkpoint(checkpoint_contents)
209
210
class PlanUtilsTest(absltest.TestCase):
  """Tests plan_utils.Session using the synthetic plan from create_plan()."""

  def test_session_enter_exit(self):
    """A default TF session is only present inside the Session context."""
    self.assertIsNone(tf.compat.v1.get_default_session())
    with plan_utils.Session(create_plan(), create_checkpoint()):
      active_session = tf.compat.v1.get_default_session()
      self.assertIsNotNone(active_session)
    self.assertIsNone(tf.compat.v1.get_default_session())

  def test_session_without_phase(self):
    """A plan with no phases is rejected with ValueError."""
    phaseless_plan = create_plan()
    phaseless_plan.ClearField('phase')
    with self.assertRaises(ValueError):
      plan_utils.Session(phaseless_plan, create_checkpoint())

  def test_session_without_server_phase(self):
    """A plan whose phase lacks a server phase is rejected with ValueError."""
    client_only_plan = create_plan()
    first_phase = client_only_plan.phase[0]
    first_phase.ClearField('server_phase')
    with self.assertRaises(ValueError):
      plan_utils.Session(client_only_plan, create_checkpoint())

  def test_session_with_multiple_phases(self):
    """A plan with more than one phase is rejected with ValueError."""
    two_phase_plan = create_plan()
    two_phase_plan.phase.append(two_phase_plan.phase[0])
    with self.assertRaises(ValueError):
      plan_utils.Session(two_phase_plan, create_checkpoint())

  def test_session_client_plan(self):
    """The client plan carries the phase, graph, TFLite graph and config."""
    source_plan = create_plan()
    with plan_utils.Session(source_plan, create_checkpoint()) as session:
      actual = plan_pb2.ClientOnlyPlan.FromString(session.client_plan)
      expected = plan_pb2.ClientOnlyPlan(
          phase=source_plan.phase[0].client_phase,
          graph=source_plan.client_graph_bytes.value,
          tflite_graph=source_plan.client_tflite_graph_bytes,
          tensorflow_config_proto=source_plan.tensorflow_config_proto,
      )
      self.assertEqual(actual, expected)

  def test_session_client_plan_without_tensorflow_config(self):
    """Without a TF config in the plan, the client plan has none either."""
    source_plan = create_plan()
    source_plan.ClearField('tensorflow_config_proto')
    with plan_utils.Session(source_plan, create_checkpoint()) as session:
      actual = plan_pb2.ClientOnlyPlan.FromString(session.client_plan)
      expected = plan_pb2.ClientOnlyPlan(
          phase=source_plan.phase[0].client_phase,
          graph=source_plan.client_graph_bytes.value,
          tflite_graph=source_plan.client_tflite_graph_bytes,
      )
      self.assertEqual(actual, expected)

  def test_session_client_plan_without_tflite_graph(self):
    """Without TFLite bytes in the plan, the client plan has none either."""
    source_plan = create_plan()
    source_plan.ClearField('client_tflite_graph_bytes')
    with plan_utils.Session(source_plan, create_checkpoint()) as session:
      actual = plan_pb2.ClientOnlyPlan.FromString(session.client_plan)
      expected = plan_pb2.ClientOnlyPlan(
          phase=source_plan.phase[0].client_phase,
          graph=source_plan.client_graph_bytes.value,
          tensorflow_config_proto=source_plan.tensorflow_config_proto,
      )
      self.assertEqual(actual, expected)

  def test_session_client_checkpoint(self):
    """The client checkpoint is '<restored data> <served-at id>'."""
    expected = b'test-client-checkpoint'
    plan = create_plan()
    initial_checkpoint = test_utils.create_checkpoint(
        {CHECKPOINT_TENSOR_NAME: expected}
    )
    with plan_utils.Session(plan, initial_checkpoint) as session:
      actual = session.client_checkpoint
      first_served_at_id = next(iter(session.slices))
      self.assertEqual(
          actual, b' '.join([expected, first_served_at_id.encode()])
      )

  def test_session_client_checkpoint_without_server_savepoint(self):
    """With no server savepoint, the default checkpoint data is used."""
    plan_without_savepoint = create_plan()
    # If server_savepoint isn't set, the checkpoint shouldn't be loaded.
    plan_without_savepoint.ClearField('server_savepoint')
    expected_prefix = DEFAULT_INITIAL_CHECKPOINT + b' '
    with plan_utils.Session(
        plan_without_savepoint, create_checkpoint()
    ) as session:
      self.assertStartsWith(session.client_checkpoint, expected_prefix)

  def test_session_finalize(self):
    """finalize() runs the plan's ops in order and yields the final value."""
    # Expected op-invocation log, split by the stage that produces it.
    setup_log = [
        'server_savepoint/before_restore',
        'server_savepoint/restore',
        'server_savepoint/after_restore',
        'phase_init',
        'write_client_init/before_save',
        'write_client_init/save',
        'write_client_init/after_save',
    ]
    slices_log = ['slices/initialize', 'slices/save_slice'] * NUM_SLICES
    finalize_log = [
        'read_intermediate_update/before_restore',
        'read_intermediate_update/restore',
        'read_intermediate_update/after_restore',
        'apply_aggregated_updates',
        'server_savepoint/before_save',
        'server_savepoint/save',
        'server_savepoint/after_save',
    ]

    with tempfile.NamedTemporaryFile('r') as tmpfile:
      with plan_utils.Session(
          create_plan(tmpfile.name), create_checkpoint()
      ) as session:
        intermediate_update = test_utils.create_checkpoint(
            {INTERMEDIATE_TENSOR_NAME: 3}
        )
        final_checkpoint = session.finalize(intermediate_update)
      logged_ops = tmpfile.read().splitlines()
      self.assertSequenceEqual(
          logged_ops, setup_log + slices_log + finalize_log
      )

    final_value = test_utils.read_tensor_from_checkpoint(
        final_checkpoint, FINAL_TENSOR_NAME, tf.int32
    )
    # The value should be propagated from the intermediate aggregate.
    self.assertEqual(final_value, 3)

  def test_session_with_tensorflow_error(self):
    """Referencing a nonexistent op in the plan surfaces as ValueError."""
    broken_plan = create_plan()
    broken_plan.phase[0].server_phase.phase_init_op = 'does-not-exist'
    with self.assertRaises(ValueError):
      plan_utils.Session(broken_plan, create_checkpoint())

  def test_session_slices(self):
    """Slices are keyed by the served-at id and hold one entry per key."""
    with plan_utils.Session(create_plan(), create_checkpoint()) as session:
      # The served_at_id should match the value in the client checkpoint.
      checkpoint_fields = session.client_checkpoint.split(b' ')
      served_at_id = checkpoint_fields[1].decode()
      self.assertSameElements(session.slices.keys(), [served_at_id])
      # Each slice's content is its key rendered as a decimal string.
      self.assertListEqual(
          session.slices[served_at_id],
          [b'%d' % key for key in range(NUM_SLICES)],
      )

  def test_session_without_slices(self):
    """No slices are served when the session token tensor name is unset."""
    plan_without_token = create_plan()
    write_client_init = plan_without_token.phase[0].server_phase.write_client_init
    write_client_init.ClearField('session_token_tensor_name')
    with plan_utils.Session(plan_without_token, create_checkpoint()) as session:
      self.assertEmpty(session.slices)
347
348
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
351