1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""A library responsible for building Federated Compute plans.
15
16This library builds TFF-backed plans, using the `MapReduceForm` object
17output by the TFF compiler pipeline.
18"""
19
20import collections
21from collections.abc import Callable, Iterable, Mapping, Sequence
22import enum
23from typing import Optional, TypeVar, Union
24
25import attr
26import tensorflow as tf
27import tensorflow_federated as tff
28
29from fcp.artifact_building import artifact_constants
30from fcp.artifact_building import checkpoint_type
31from fcp.artifact_building import checkpoint_utils
32from fcp.artifact_building import data_spec
33from fcp.artifact_building import graph_helpers
34from fcp.artifact_building import proto_helpers
35from fcp.artifact_building import tensor_utils
36from fcp.artifact_building import type_checks
37from fcp.artifact_building import variable_helpers
38from fcp.protos import plan_pb2
39from fcp.tensorflow import append_slices
40from fcp.tensorflow import delete_file
41
42SECURE_SUM_BITWIDTH_URI = 'federated_secure_sum_bitwidth'
43SECURE_SUM_URI = 'federated_secure_sum'
44SECURE_MODULAR_SUM_URI = 'federated_secure_modular_sum'
45
46
47class SecureAggregationTensorShapeError(Exception):
48  """Error raised when secagg tensors do not have fully defined shape."""
49
50
51@enum.unique
52class ClientPlanType(enum.Enum):
53  """Option adjusting client plan type during plan building.
54
55  Values:
56    TENSORFLOW: The default value. Uses a TF client graph for client
57      computation.
58    EXAMPLE_QUERY: Uses an example query containing client computation logic in
59      the provided example selector(s).
60  """
61
62  TENSORFLOW = 'tensorflow'
63  EXAMPLE_QUERY = 'example_query'
64
65
66# A type representing a potentially nested struct of integers.
67IntStruct = Union[
68    int,
69    Mapping[str, Union['IntStruct', int]],
70    Sequence[Union['IntStruct', int]],
71]
72
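# A hypothetical example of such a structure (illustration only; not used by
# the plan builder): per-tensor bitwidths keyed by tensor name, with a nested
# sequence for a multi-tensor layer.
_EXAMPLE_INT_STRUCT: IntStruct = {'kernel': 8, 'bias': [4, 4]}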
73
74def _compute_secagg_parameters(
75    mrf: tff.backends.mapreduce.MapReduceForm,
76) -> tuple[IntStruct, IntStruct, IntStruct]:
77  """Executes the TensorFlow logic that computes the SecAgg parameters.
78
79  This function makes use of `mrf.secure_sum_bitwidth`,
80  `mrf.secure_sum_max_input`, and `mrf.secure_modular_sum_modulus` to derive
81  the parameters needed for the SecAgg protocol variants.
82
83  Args:
84    mrf: An instance of `tff.backends.mapreduce.MapReduceForm`.
85
86  Returns:
87    A 3-tuple of `bitwidth`, `max_input` and `moduli` structures of parameters
88    for the associated SecAgg variants.
89  """
90  type_checks.check_type(mrf, tff.backends.mapreduce.MapReduceForm, name='mrf')
91  secagg_parameters = []
92  with tf.Graph().as_default() as g:
93    for name, computation in [
94        ('bitwidth', mrf.secure_sum_bitwidth),
95        ('max_input', mrf.secure_sum_max_input),
96        ('modulus', mrf.secure_modular_sum_modulus),
97    ]:
98      secagg_parameters.append(
99          graph_helpers.import_tensorflow(name, computation)
100      )
101  with tf.compat.v1.Session(graph=g) as sess:
102    flat_output = sess.run(fetches=tf.nest.flatten(secagg_parameters))
103    return tf.nest.pack_sequence_as(secagg_parameters, flat_output)
104
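
# Illustrative sketch only: a hypothetical helper (not referenced elsewhere in
# this file) showing how the nested structures returned by
# `_compute_secagg_parameters` are typically flattened into the flat,
# per-tensor lists that `_build_server_graph` below expects.
def _example_flatten_secagg_parameters(
    mrf: tff.backends.mapreduce.MapReduceForm,
) -> tuple[list[int], list[int], list[int]]:
  bitwidths, max_inputs, moduli = _compute_secagg_parameters(mrf)
  return (
      tf.nest.flatten(bitwidths),
      tf.nest.flatten(max_inputs),
      tf.nest.flatten(moduli),
  )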
105
106# A side-channel through which one tensor is securely aggregated.
107@attr.s
108class SecAggSideChannel:
109  # The name of the tensor being aggregated in the client and server graphs.
110  tensor_name: str = attr.ib()
111  # A proto describing how the side-channel is to be aggregated.
112  side_channel_proto: plan_pb2.SideChannel = attr.ib()
113  # A placeholder tensor into which the sidechannel aggregation result is fed.
114  placeholder: tf.Tensor = attr.ib()
115  # The variable to feed into the server graph.
116  update_var: tf.Variable = attr.ib()
117
118
119SecAggParam = TypeVar('SecAggParam')
120
121
122def _create_secagg_sidechannels(
123    intrinsic_name: str,
124    update_type: variable_helpers.AllowedTffTypes,
125    get_modulus_scheme: Callable[
126        [SecAggParam], plan_pb2.SideChannel.SecureAggregand
127    ],
128    params: list[SecAggParam],
129) -> list[SecAggSideChannel]:
130  """Returns `SecAggSideChannel`s for tensors aggregated with `intrinsic_name`.
131
132  This method also creates variables for the securely-aggregated tensors within
133  the current default graph using `create_vars_for_tff_type`.
134
135  Args:
136    intrinsic_name: The name of the intrinsic (e.g.
137      `federated_secure_sum_bitwidth`) with which the tensors in `update_type`
138      are being aggregated.
139    update_type: The TFF type representing a structure of all tensors being
140      aggregated with `intrinsic_name`.
141    get_modulus_scheme: A function which will get the modulus scheme being used.
142      This typically requires some additional per-tensor parameters which must
143      be supplied using `params`.
144    params: A list of arguments to pass to `get_modulus_scheme`. There must be
145      exactly one element in this list per tensor in `update_type`.
146
147  Returns:
148    A list of `SecAggSideChannel`s describing how to aggregate each tensor.
149  """
150  # For secure aggregation, we don't use a saver (but still store metadata in a
151  # CheckpointOp). Instead we create sidechannel tensors that get fed into the
152  # server graph.
153  update_vars = variable_helpers.create_vars_for_tff_type(
154      update_type, f'{intrinsic_name}_update'
155  )
156
157  # For tensors aggregated by secagg, we make sure the tensor names are aligned
158  # in both client and server graphs by getting the names from the same method.
159  tensor_names = variable_helpers.get_shared_secagg_tensor_names(
160      intrinsic_name, update_type
161  )
162  assert len(update_vars) == len(params) == len(tensor_names), (
163      'The length of update_vars, params and tensor_names for'
163      f' {intrinsic_name} should all be equal, but found: {len(update_vars)},'
165      f' {len(params)}, and {len(tensor_names)}.'
166  )
167
168  results = []
169  for param, update_var, tensor_name in zip(params, update_vars, tensor_names):
170    secure_aggregand = get_modulus_scheme(param)
171    secure_aggregand.dimension.extend(
172        plan_pb2.SideChannel.SecureAggregand.Dimension(size=d.value)
173        for d in update_var.shape.dims
174    )
175    secure_aggregand.dtype = update_var.dtype.base_dtype.as_datatype_enum
176    placeholder = tf.compat.v1.placeholder(
177        update_var.dtype, update_var.get_shape()
178    )
179    side_channel_proto = plan_pb2.SideChannel(
180        secure_aggregand=secure_aggregand, restore_name=placeholder.name
181    )
182    results.append(
183        SecAggSideChannel(
184            tensor_name=tensor_name,
185            side_channel_proto=side_channel_proto,
186            placeholder=placeholder,
187            update_var=update_var,
188        )
189    )
190  return results
191
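
# Minimal usage sketch (hypothetical; not referenced elsewhere in this file):
# builds side-channel descriptions for a single int32 vector aggregated via
# `federated_secure_sum_bitwidth`, mirroring `_aggregand_for_bitwidth` below.
# The update type, tensor name, and bitwidth value are assumptions; a TF1-style
# graph is required because placeholders and variables are created.
def _example_bitwidth_sidechannels() -> list[SecAggSideChannel]:
  example_update_type = tff.to_type(
      collections.OrderedDict(delta=(tf.int32, [3]))
  )
  with tf.Graph().as_default():
    return _create_secagg_sidechannels(
        SECURE_SUM_BITWIDTH_URI,
        example_update_type,
        lambda bitwidth: plan_pb2.SideChannel.SecureAggregand(
            quantized_input_bitwidth=bitwidth
        ),
        params=[8],  # One bitwidth per tensor in `example_update_type`.
    )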
192
193def _read_secagg_update_from_sidechannel_into_vars(
194    *,  # Require parameters to be named.
195    secagg_intermediate_update_vars: list[tf.Variable],
196    secure_sum_bitwidth_update_type: (variable_helpers.AllowedTffTypes),
197    bitwidths: list[int],
198    secure_sum_update_type: (variable_helpers.AllowedTffTypes),
199    max_inputs: list[int],
200    secure_modular_sum_update_type: (variable_helpers.AllowedTffTypes),
201    moduli: list[int],
202) -> plan_pb2.CheckpointOp:
203  """Creates the `read_secagg_update` op.
204
205  `read_secagg_update` is a `plan_pb2.CheckpointOp` used to restore the
206  secagg tensors in the server graph.
207
208  Args:
209    secagg_intermediate_update_vars: A list of variables to assign the
210      secagg_update_data in the `after_restore_op`.
211    secure_sum_bitwidth_update_type: The type of the tensors aggregated using
212      `bitwidth`-based secure sum.
213    bitwidths: The `bitwidth`s for the tensors that will be aggregated using
214      `bitwidth`-based secure summation.
215    secure_sum_update_type: The type of the tensors aggregated using
216      `max_input`-based secure sum.
217    max_inputs: The `max_input`s for the tensors that will be aggregated using
218      `max_input`-based secure summation.
219    secure_modular_sum_update_type: The type of the tensors aggregated using
220      modular secure summation.
221    moduli: The `modulus`s for the tensors that will be aggregated using modular
222      secure summation.
223
224  Returns:
225    A `plan_pb2.CheckpointOp` which performs the `read_secagg_update`.
226  """
227  side_channels: list[SecAggSideChannel] = []
228
229  def _aggregand_for_bitwidth(bitwidth):
230    return plan_pb2.SideChannel.SecureAggregand(
231        quantized_input_bitwidth=bitwidth
232    )
233
234  side_channels += _create_secagg_sidechannels(
235      SECURE_SUM_BITWIDTH_URI,
236      secure_sum_bitwidth_update_type,
237      _aggregand_for_bitwidth,
238      bitwidths,
239  )
240
241  def _aggregand_for_max_input(max_input):
242    # Note the +1-- `max_input` is inclusive, so `base_modulus == max_input`
243    # would overflow maximum-valued inputs to zero.
244    base_modulus = max_input + 1
245    modulus_times_shard_size = (
246        plan_pb2.SideChannel.SecureAggregand.ModulusTimesShardSize(
247            base_modulus=base_modulus
248        )
249    )
250    return plan_pb2.SideChannel.SecureAggregand(
251        modulus_times_shard_size=modulus_times_shard_size
252    )
253
254  side_channels += _create_secagg_sidechannels(
255      SECURE_SUM_URI,
256      secure_sum_update_type,
257      _aggregand_for_max_input,
258      max_inputs,
259  )
260
261  def _aggregand_for_modulus(modulus):
262    fixed_modulus = plan_pb2.SideChannel.SecureAggregand.FixedModulus(
263        modulus=modulus
264    )
265    return plan_pb2.SideChannel.SecureAggregand(fixed_modulus=fixed_modulus)
266
267  side_channels += _create_secagg_sidechannels(
268      SECURE_MODULAR_SUM_URI,
269      secure_modular_sum_update_type,
270      _aggregand_for_modulus,
271      moduli,
272  )
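  # Note: the ordering above (bitwidth-based tensors, then max_input-based
  # tensors, then modular-sum tensors) is expected to match the ordering of
  # `secagg_intermediate_update_vars` supplied by the caller.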
273
274  # Operations assigning from sidechannel placeholders to update variables.
275  assign_placeholders_to_updates = []
276  # Operations assigning from update variables to the result variables.
277  assign_updates_to_intermediate = []
278  read_secagg_update = plan_pb2.CheckpointOp()
279  for intermediate_update_var, side_channel in zip(
280      secagg_intermediate_update_vars, side_channels
281  ):
282    assign_placeholders_to_updates.append(
283        side_channel.update_var.assign(side_channel.placeholder)
284    )
285    assign_updates_to_intermediate.append(
286        intermediate_update_var.assign(side_channel.update_var)
287    )
288    read_secagg_update.side_channel_tensors[side_channel.tensor_name].CopyFrom(
289        side_channel.side_channel_proto
290    )
291
292  read_secagg_update.before_restore_op = tf.group(
293      *(assign_placeholders_to_updates)
294  ).name
295  read_secagg_update.after_restore_op = tf.group(
296      *(assign_updates_to_intermediate)
297  ).name
298
299  return read_secagg_update
300
301
302def _merge_secagg_vars(
303    secure_sum_bitwidth_update_type: tff.Type,
304    secure_sum_update_type: tff.Type,
305    secure_modular_sum_update_type: tff.Type,
306    flattened_moduli: list[int],
307    variables: list[tf.Variable],
308    tensors: list[tf.Variable],
309) -> list[tf.Operation]:
310  """Generates a set of ops to `merge` secagg `tensors` into `variables`."""
311  if len(variables) != len(tensors):
312    raise ValueError(
313        'Expected an equal number of variables and tensors, but found '
314        f'{len(variables)} variables and {len(tensors)} tensors.'
315    )
316  num_simple_add_vars = len(
317      tff.structure.flatten(
318          tff.to_type([
319              secure_sum_bitwidth_update_type,
320              secure_sum_update_type,
321          ])
322      )
323  )
324  num_modular_add_vars = len(
325      tff.structure.flatten(secure_modular_sum_update_type)
326  )
327  # There must be one variable and tensor for each tensor in the secure update
328  # types.
329  num_vars_from_types = num_simple_add_vars + num_modular_add_vars
330  if num_vars_from_types != len(variables):
331    raise ValueError(
332        'Expected one variable for each leaf element of the secagg update, but '
333        f'found {len(variables)} variables and {num_vars_from_types} leaf '
334        'elements in the following types:\n'
335        f'secure_sum_bitwidth_update_type: {secure_sum_bitwidth_update_type}\n'
336        f'secure_sum_update_type: {secure_sum_update_type}\n'
337        f'secure_modular_sum_update_type: {secure_modular_sum_update_type}\n'
338    )
339  if num_modular_add_vars != len(flattened_moduli):
340    raise ValueError(
341        'Expected one modulus for each leaf element of the secure modular sum '
342        f'update type. Found {len(flattened_moduli)} moduli and '
343        f'{num_modular_add_vars} leaf elements in the secure modular sum '
344        f'update type:\n{secure_modular_sum_update_type}'
345    )
346  # Add `tensors` to `variables`, using simple addition for the first
347  # `num_simple_add_vars` variables and modular addition for the rest
348  # (those coming from `secure_modular_sum`).
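  # Worked example (illustrative): with a modulus of 7, an accumulator value of
  # 5, and an incoming tensor value of 4, the merged value is (5 + 4) % 7 = 2,
  # whereas the first `num_simple_add_vars` variables are combined with plain
  # `assign_add`.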
349  ops = []
350  simple_add_vars = variables[:num_simple_add_vars]
351  simple_add_tensors = tensors[:num_simple_add_vars]
352  for variable, tensor in zip(simple_add_vars, simple_add_tensors):
353    ops.append(variable.assign_add(tensor))
354  modular_add_vars = variables[num_simple_add_vars:]
355  modular_add_tensors = tensors[num_simple_add_vars:]
356  for modulus, (variable, tensor) in zip(
357      flattened_moduli, zip(modular_add_vars, modular_add_tensors)
358  ):
359    new_sum = tf.math.add(variable.read_value(), tensor)
360    modular_sum = tf.math.floormod(new_sum, modulus)
361    ops.append(variable.assign(tf.reshape(modular_sum, tf.shape(variable))))
362  return ops
363
364
365def _build_server_graphs_from_distribute_aggregate_form(
366    daf: tff.backends.mapreduce.DistributeAggregateForm,
367    is_broadcast_empty: bool,
368    grappler_config: tf.compat.v1.ConfigProto,
369) -> tuple[
370    tf.compat.v1.GraphDef, tf.compat.v1.GraphDef, plan_pb2.ServerPhaseV2
371]:
372  """Generates the server plan components based on DistributeAggregateForm.
373
374  Derives the pre-broadcast, aggregation, and post-aggregation logical
375  components in the ServerPhaseV2 message that will be executed on the server.
376  The pre-broadcast and post-aggregation components are to be executed with a
377  single TF sess.run call using the corresponding GraphDef. The aggregation
378  component is to be executed natively (i.e. not using TensorFlow) according to
379  the aggregation messages contained in the ServerPhaseV2 message.
380
381  Args:
382    daf: An instance of `tff.backends.mapreduce.DistributeAggregateForm`.
383    is_broadcast_empty: A boolean indicating whether the broadcasted value from
384      the server is expected to be empty based on the DistributeAggregateForm,
385      in which case the server should broadcast a placeholder tf.int32 tensor as
386      empty checkpoints are not supported.
387    grappler_config: The config specifying Grappler optimizations for TFF-
388      generated graphs.
389
390  Returns:
391    A tuple containing the server_prepare GraphDef, the server_result GraphDef,
392    and the ServerPhaseV2 message.
393  """
394  # Generate the TensorFlow graph needed to execute the server_prepare step,
395  # including reading input checkpoints and writing output checkpoints.
396  server_prepare_input_tensors = []
397  server_prepare_target_nodes = []
398  with tf.Graph().as_default() as server_prepare_graph:
399    # Create the placeholders for the input and output filenames needed by
400    # the server_prepare step.
401    server_prepare_server_state_input_filepath_placeholder = (
402        tf.compat.v1.placeholder(
403            name='server_state_input_filepath', shape=(), dtype=tf.string
404        )
405    )
406    server_prepare_output_filepath_placeholder = tf.compat.v1.placeholder(
407        name='server_prepare_output_filepath', shape=(), dtype=tf.string
408    )
409    server_prepare_intermediate_state_output_filepath_placeholder = (
410        tf.compat.v1.placeholder(
411            name='server_intermediate_state_output_filepath',
412            shape=(),
413            dtype=tf.string,
414        )
415    )
416    server_prepare_input_tensors.extend([
417        server_prepare_server_state_input_filepath_placeholder,
418        server_prepare_output_filepath_placeholder,
419        server_prepare_intermediate_state_output_filepath_placeholder,
420    ])
421
422    # Restore the server state.
423    server_state_type = daf.server_prepare.type_signature.parameter
424    server_state_vars = variable_helpers.create_vars_for_tff_type(
425        server_state_type, name='server'
426    )
427    server_state_tensor_specs = tf.nest.map_structure(
428        variable_helpers.tensorspec_from_var, server_state_vars
429    )
430    server_state = checkpoint_utils.restore_tensors_from_savepoint(
431        server_state_tensor_specs,
432        server_prepare_server_state_input_filepath_placeholder,
433    )
434
435    # TODO(team): Add support for federated select slice generation.
436
437    # Perform the server_prepare step.
438    prepared_values, intermediate_state_values = (
439        graph_helpers.import_tensorflow(
440            'server_prepare',
441            tff.framework.ConcreteComputation.from_building_block(
442                tff.backends.mapreduce.consolidate_and_extract_local_processing(
443                    daf.server_prepare.to_building_block(), grappler_config
444                )
445            ),
446            server_state,
447            split_outputs=True,
448        )
449    )
450
451    # Create checkpoints storing the broadcast values and intermediate server
452    # state. If there is no broadcast value, create a checkpoint containing a
453    # placeholder tf.int32 constant since empty broadcasts are not supported.
454    # If there is no intermediate server state, don't create an intermediate
455    # server state checkpoint.
456    save_tensor_names = variable_helpers.variable_names_from_type(
457        daf.server_prepare.type_signature.result[0], name='client'
458    )
459    save_tensors = prepared_values
460    if is_broadcast_empty:
461      save_tensor_names = variable_helpers.variable_names_from_type(
462          tff.StructType([tf.int32]), name='client'
463      )
464      save_tensors = [tf.constant(0, tf.int32)]
465    prepared_values_save_op = tensor_utils.save(
466        filename=server_prepare_output_filepath_placeholder,
467        tensor_names=save_tensor_names,
468        tensors=save_tensors,
469        name='save_prepared_values_tensors',
470    )
471    server_prepare_target_nodes.append(prepared_values_save_op.name)
472
473    intermediate_state_empty = (
474        isinstance(daf.server_prepare.type_signature.result[1], tff.StructType)
475        and not daf.server_prepare.type_signature.result[1]
476    )
477    if not intermediate_state_empty:
478      intermediate_state_values_save_op = tensor_utils.save(
479          filename=server_prepare_intermediate_state_output_filepath_placeholder,
480          tensor_names=variable_helpers.variable_names_from_type(
481              daf.server_prepare.type_signature.result[1], 'intermediate_state'
482          ),
483          tensors=intermediate_state_values,
484          name='save_intermediate_state_values_tensors',
485      )
486      server_prepare_target_nodes.append(intermediate_state_values_save_op.name)
487
488  # Build aggregations.
489  # The client_to_server_aggregation computation is guaranteed to conform to
490  # a specific structure. It is a lambda computation whose result block contains
491  # locals that are exclusively aggregation-type intrinsics.
492  aggregations_bb = daf.client_to_server_aggregation.to_building_block()
493  aggregations_bb.check_lambda()
494  aggregations_bb.result.check_block()  # pytype: disable=attribute-error
495
496  # Get lists of the TensorSpecProtos for the inputs and outputs of all
497  # intrinsic calls. These lists are formatted such that the ith entry
498  # represents the TensorSpecProtos for the ith intrinsic in the aggregation
499  # computation. Since intrinsics may have one or more args, the ith entry in
500  # the input TensorSpecProto list is itself a list, where the jth entry
501  # represents the TensorSpecProtos corresponding to the jth argument of the
502  # ith intrinsic.
503  grouped_input_tensor_specs = variable_helpers.get_grouped_input_tensor_specs_for_aggregations(
504      aggregations_bb,
505      artifact_constants.AGGREGATION_INTRINSIC_ARG_SELECTION_INDEX_TO_NAME_DICT,
506  )
507  grouped_output_tensor_specs = (
508      variable_helpers.get_grouped_output_tensor_specs_for_aggregations(
509          aggregations_bb
510      )
511  )
512  assert len(grouped_input_tensor_specs) == len(grouped_output_tensor_specs)
513
514  intrinsic_uris = [
515      local_value.function.intrinsic_def().uri
516      for _, local_value in aggregations_bb.result.locals  # pytype: disable=attribute-error
517  ]
518  assert len(intrinsic_uris) == len(grouped_output_tensor_specs)
519
520  # Each intrinsic input arg can be a struct or even a nested struct, which
521  # requires the intrinsic to be applied independently to each element (e.g. a
522  # tff.federated_sum call applied to a struct will result in a federated_sum
523  # aggregation message for each element of the struct). Note that elements of
524  # structs can themselves be multi-dimensional tensors. When an intrinsic call
525  # has multiple args with mismatching structure (e.g. a federated_weighted_mean
526  # intrinsic applied to a 2D struct value arg and scalar weight arg), some args
527  # will need to be "scaled up" via repetition to match the args with the
528  # "largest" structure.
529  aggregations = []
530  for intrinsic_index, (input_tensor_specs, output_tensor_specs) in enumerate(
531      zip(grouped_input_tensor_specs, grouped_output_tensor_specs)
532  ):
533    # Generate the aggregation messages for this intrinsic call.
534    max_input_struct_length = max([len(x) for x in input_tensor_specs])
535    max_struct_length = max(max_input_struct_length, len(output_tensor_specs))
536    for i in range(max_struct_length):
537      intrinsic_args = []
538      for j, _ in enumerate(input_tensor_specs):
539        # Scale up any "smaller" structure args by reusing their last element.
540        tensor_spec = input_tensor_specs[j][
541            min(i, len(input_tensor_specs[j]) - 1)
542        ]
543        if tensor_spec.name.startswith('update'):
544          intrinsic_args.append(
545              plan_pb2.ServerAggregationConfig.IntrinsicArg(
546                  input_tensor=tensor_spec.experimental_as_proto()
547              )
548          )
549        else:
550          intrinsic_args.append(
551              plan_pb2.ServerAggregationConfig.IntrinsicArg(
552                  state_tensor=tensor_spec.experimental_as_proto()
553              )
554          )
555      aggregations.append(
556          plan_pb2.ServerAggregationConfig(
557              intrinsic_uri=intrinsic_uris[intrinsic_index],
558              intrinsic_args=intrinsic_args,
559              # Scale up the output structure by reusing the last element if
560              # needed.
561              output_tensors=[
562                  output_tensor_specs[
563                      min(i, len(output_tensor_specs) - 1)
564                  ].experimental_as_proto()
565              ],
566          )
567      )
568
569  # Generate the TensorFlow graph needed to execute the server_result step,
570  # including reading input checkpoints, writing output checkpoints, and
571  # generating output tensors.
572  server_result_input_tensors = []
573  server_result_output_tensors = []
574  server_result_target_nodes = []
575  with tf.Graph().as_default() as server_result_graph:
576    # Create the placeholders for the input and output filenames needed by
577    # the server_result step.
578    server_result_intermediate_state_input_filepath_placeholder = (
579        tf.compat.v1.placeholder(
580            name='server_intermediate_state_input_filepath',
581            shape=(),
582            dtype=tf.string,
583        )
584    )
585    server_result_aggregate_result_input_filepath_placeholder = (
586        tf.compat.v1.placeholder(
587            name='aggregate_result_input_filepath', shape=(), dtype=tf.string
588        )
589    )
590    server_result_server_state_output_filepath_placeholder = (
591        tf.compat.v1.placeholder(
592            name='server_state_output_filepath', shape=(), dtype=tf.string
593        )
594    )
595    server_result_input_tensors.extend([
596        server_result_intermediate_state_input_filepath_placeholder,
597        server_result_aggregate_result_input_filepath_placeholder,
598        server_result_server_state_output_filepath_placeholder,
599    ])
600
601    # Restore the intermediate server state.
602    intermediate_state = []
603    if not intermediate_state_empty:
604      intermediate_state_type = daf.server_result.type_signature.parameter[0]
605      intermediate_state_vars = variable_helpers.create_vars_for_tff_type(
606          intermediate_state_type, 'intermediate_state'
607      )
608      intermediate_state_tensor_specs = tf.nest.map_structure(
609          variable_helpers.tensorspec_from_var, intermediate_state_vars
610      )
611      intermediate_state = checkpoint_utils.restore_tensors_from_savepoint(
612          intermediate_state_tensor_specs,
613          server_result_intermediate_state_input_filepath_placeholder,
614      )
615
616    # Restore the aggregation results.
617    aggregate_result_type = tff.StructType(
618        [daf.server_result.type_signature.parameter[1]]
619    )
620    aggregate_result_vars = variable_helpers.create_vars_for_tff_type(
621        aggregate_result_type, 'intermediate_update'
622    )
623    aggregate_result_tensor_specs = tf.nest.map_structure(
624        variable_helpers.tensorspec_from_var, aggregate_result_vars
625    )
626    aggregate_result = checkpoint_utils.restore_tensors_from_savepoint(
627        aggregate_result_tensor_specs,
628        server_result_aggregate_result_input_filepath_placeholder,
629    )
630
631    # Perform the server_result step.
632    server_state_values, server_output_values = graph_helpers.import_tensorflow(
633        'server_result',
634        tff.framework.ConcreteComputation.from_building_block(
635            tff.backends.mapreduce.consolidate_and_extract_local_processing(
636                daf.server_result.to_building_block(), grappler_config
637            )
638        ),
639        (intermediate_state, aggregate_result),
640        split_outputs=True,
641    )
642
643    # Create checkpoints storing the updated server state.
644    server_state_save_op = tensor_utils.save(
645        filename=server_result_server_state_output_filepath_placeholder,
646        tensor_names=variable_helpers.variable_names_from_type(
647            daf.server_result.type_signature.result[0], 'server'
648        ),
649        tensors=server_state_values,
650        name='save_server_state_tensors',
651    )
652    server_result_target_nodes.append(server_state_save_op.name)
653
654    # Generate the output TensorSpecProtos for the server metrics if some exist.
655    server_output_empty = (
656        isinstance(daf.server_result.type_signature.result[1], tff.StructType)
657        and not daf.server_result.type_signature.result[1]
658    )
659    if not server_output_empty:
660      metric_names = variable_helpers.variable_names_from_type(
661          daf.server_result.type_signature.result[1], 'server'
662      )
663      metric_tensors = [
664          tf.identity(tensor, name)
665          for tensor, name in zip(server_output_values, metric_names)
666      ]
667      for metric in metric_tensors:
668        server_result_output_tensors.append(
669            proto_helpers.make_tensor_spec_from_tensor(
670                metric
671            ).experimental_as_proto()
672        )
673
674  # Create the TensorflowSpec messages for the pre-broadcast (server_prepare)
675  # and post-aggregation (server_result) steps.
676  tensorflow_spec_prepare = plan_pb2.TensorflowSpec(
677      input_tensor_specs=[
678          proto_helpers.make_tensor_spec_from_tensor(t).experimental_as_proto()
679          for t in server_prepare_input_tensors
680      ],
681      target_node_names=server_prepare_target_nodes,
682  )
683  tensorflow_spec_result = plan_pb2.TensorflowSpec(
684      input_tensor_specs=[
685          proto_helpers.make_tensor_spec_from_tensor(t).experimental_as_proto()
686          for t in server_result_input_tensors
687      ],
688      output_tensor_specs=server_result_output_tensors,
689      target_node_names=server_result_target_nodes,
690  )
691
692  # Create the IORouter messages for the pre-broadcast (server_prepare) and
693  # post-aggregation (server_result) steps.
694  server_prepare_io_router = plan_pb2.ServerPrepareIORouter(
695      prepare_server_state_input_filepath_tensor_name=server_prepare_server_state_input_filepath_placeholder.name,
696      prepare_output_filepath_tensor_name=server_prepare_output_filepath_placeholder.name,
697      prepare_intermediate_state_output_filepath_tensor_name=server_prepare_intermediate_state_output_filepath_placeholder.name,
698  )
699  server_result_io_router = plan_pb2.ServerResultIORouter(
700      result_intermediate_state_input_filepath_tensor_name=server_result_intermediate_state_input_filepath_placeholder.name,
701      result_aggregate_result_input_filepath_tensor_name=server_result_aggregate_result_input_filepath_placeholder.name,
702      result_server_state_output_filepath_tensor_name=server_result_server_state_output_filepath_placeholder.name,
703  )
704
705  server_phase_v2 = plan_pb2.ServerPhaseV2(
706      tensorflow_spec_prepare=tensorflow_spec_prepare,
707      prepare_router=server_prepare_io_router,
708      aggregations=aggregations,
709      tensorflow_spec_result=tensorflow_spec_result,
710      result_router=server_result_io_router,
711  )
712
713  return (
714      server_prepare_graph.as_graph_def(),
715      server_result_graph.as_graph_def(),
716      server_phase_v2,
717  )
718
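# Hypothetical usage sketch (illustration only): given a compiled
# `tff.backends.mapreduce.DistributeAggregateForm` `daf`, the server artifacts
# above could be produced roughly as follows.
#
#   grappler_config = tf.compat.v1.ConfigProto()
#   prepare_graphdef, result_graphdef, server_phase_v2 = (
#       _build_server_graphs_from_distribute_aggregate_form(
#           daf, is_broadcast_empty=False, grappler_config=grappler_config
#       )
#   )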
719
720def _build_server_graph(
721    mrf: tff.backends.mapreduce.MapReduceForm,
722    broadcast_tff_type: variable_helpers.AllowedTffTypes,
723    is_broadcast_empty: bool,
724    flattened_bitwidths: list[int],
725    flattened_max_inputs: list[int],
726    flattened_moduli: list[int],
727    write_metrics_to_checkpoint: bool = True,
728    additional_checkpoint_metadata_var_fn: Optional[
729        Callable[[tff.StructType, tff.StructType, bool], list[tf.Variable]]
730    ] = None,
731    experimental_client_update_format: checkpoint_type.CheckpointFormatType = checkpoint_type.CheckpointFormatType.TF1_SAVE_SLICES,
732) -> tuple[
733    tf.compat.v1.GraphDef,
734    plan_pb2.CheckpointOp,
735    plan_pb2.ServerPhase,
736    list[tf.TensorSpec],
737]:
738  """Builds the `tf.Graph` that will run on the server.
739
740  Args:
741    mrf: A `MapReduceForm` object containing the different computations to
742      combine into a single server graph.
743    broadcast_tff_type: A `tff.Type` object that specifies the tensors in the
744      model that are broadcasted and aggregated.
745    is_broadcast_empty: boolean indicating whether the broadcasted value from
746      the server was initially empty.
747    flattened_bitwidths: The `bitwidth`s for the tensors that will be aggregated
748      using `bitwidth`-based secure summation.
749    flattened_max_inputs: The `max_input`s for the tensors that will be
750      aggregated using `max_input`-based secure summation.
751    flattened_moduli: The `modulus`s for the tensors that will be aggregated
752      using modular secure summation.
753    write_metrics_to_checkpoint: If False, revert to legacy behavior where
754      metrics values were handled by post-processing separate from the outputted
755      checkpoint. Regardless, they will additionally continue to be written to
756      recordio and accumulator checkpoints as defined by the Plan proto.
757    additional_checkpoint_metadata_var_fn: An optional method that takes in a
758      server state type, a server metrics type, and a boolean determining
759      whether to revert to legacy metrics behavior to produce additional
760      metadata variables.
761    experimental_client_update_format: Determines how the client update will be
762      interpreted. The value has to match the `experimental_checkpoint_write`
763      argument of the `_build_client_graph_with_tensorflow_spec` call.
764
765  Returns:
766    A `tuple` containing the following (in order):
767      - A server `tf.GraphDef`,
768      - A server checkpoint,
769      - A server phase proto message, and
770      - A list of `tf.TensorSpec`s for the broadcasted values.
771  """
772  (
773      simpleagg_update_type,
774      secure_sum_bitwidth_update_type,
775      secure_sum_update_type,
776      secure_modular_sum_update_type,
777  ) = mrf.work.type_signature.result
778  with tf.Graph().as_default() as server_graph:
779    # Creates all server-side variables and savepoints for both the coordinator
780    # and the intermediate aggregators.
781    # server_state_type will be a SERVER-placed federated type.
782    server_state_type, server_metrics_type = mrf.type_signature.result
783    assert server_state_type.is_federated(), server_state_type
784    assert server_state_type.placement == tff.SERVER, server_state_type
785    # server_metrics_type can be a tff.FederatedType or a structure containing
786    # tff.FederatedTypes.
787    if isinstance(server_metrics_type, tff.FederatedType):
788      # We need to check for server metrics without the placement so
789      # tff.structure.flatten works correctly.
790      has_server_metrics = bool(
791          tff.structure.flatten(server_metrics_type.member)
792      )
793    else:
794      has_server_metrics = bool(tff.structure.flatten(server_metrics_type))
795    if isinstance(server_metrics_type, tff.TensorType) or (
796        isinstance(server_metrics_type, tff.FederatedType)
797        and isinstance(server_metrics_type.member, tff.TensorType)
798    ):
799      # Single tensor; must be wrapped inside of a NamedTuple for proper
800      # variable initialization.
801      server_metrics_type = tff.StructType([server_metrics_type])
802    (
803        server_state_vars,
804        server_metrics_vars,
805        metadata_vars,
806        server_savepoint,
807    ) = checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
808        server_state_type=server_state_type,
809        server_metrics_type=server_metrics_type,
810        write_metrics_to_checkpoint=write_metrics_to_checkpoint,
811        additional_checkpoint_metadata_var_fn=(
812            additional_checkpoint_metadata_var_fn
813        ),
814    )
815
816    # TODO(team): Switch to `tf.save()` in lieu of savers to avoid the
817    # need to create client variables on the server.
818    client_vars_on_server, write_client = (
819        checkpoint_utils.create_state_vars_and_savepoint(
820            broadcast_tff_type, 'client'
821        )
822    )
823
824    secure_sum_update_types = [
825        secure_sum_bitwidth_update_type,
826        secure_sum_update_type,
827        secure_modular_sum_update_type,
828    ]
829    combined_intermediate_update_type = tff.StructType(
830        [mrf.zero.type_signature.result] + secure_sum_update_types
831    )
832
833    combined_intermediate_update_vars, write_intermediate_update = (
834        checkpoint_utils.create_state_vars_and_savepoint(
835            combined_intermediate_update_type, 'intermediate_update'
836        )
837    )
838    num_simpleagg_vars = len(combined_intermediate_update_vars) - len(
839        tff.structure.flatten(tff.to_type(secure_sum_update_types))
840    )
841    intermediate_update_vars = combined_intermediate_update_vars[
842        :num_simpleagg_vars
843    ]
844    secagg_intermediate_update_vars = combined_intermediate_update_vars[
845        num_simpleagg_vars:
846    ]
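    # Illustrative example: if `mrf.zero` produces three SimpleAgg accumulator
    # tensors and the three secure-sum update types flatten to two tensors in
    # total, `combined_intermediate_update_vars` has five entries; the first
    # three feed SimpleAgg and the remaining two feed SecAgg.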
847
848    read_secagg_update = _read_secagg_update_from_sidechannel_into_vars(
849        secagg_intermediate_update_vars=secagg_intermediate_update_vars,
850        secure_sum_bitwidth_update_type=secure_sum_bitwidth_update_type,
851        bitwidths=flattened_bitwidths,
852        secure_sum_update_type=secure_sum_update_type,
853        max_inputs=flattened_max_inputs,
854        secure_modular_sum_update_type=secure_modular_sum_update_type,
855        moduli=flattened_moduli,
856    )
857
858    combined_aggregated_update_vars, write_accumulators = (
859        checkpoint_utils.create_state_vars_and_savepoint(
860            combined_intermediate_update_type, 'aggregated_update'
861        )
862    )
863    aggregated_update_vars = combined_aggregated_update_vars[
864        :num_simpleagg_vars
865    ]
866    secagg_aggregated_update_vars = combined_aggregated_update_vars[
867        num_simpleagg_vars:
868    ]
869
870    # Run the initializer for all state variables prior to restoring from the
871    # savepoint. Running the initializer first allows the variables to be
872    # overwritten by the savepoint, and ensures they do not get
873    # re-initialized after being overwritten.
874    # Also include the metrics vars here in case the execution
875    # environment wants to read those in.
876    server_vars_initializer = tf.compat.v1.variables_initializer(
877        server_state_vars + metadata_vars + server_metrics_vars,
878        'initialize_server_state_and_non_state_vars',
879    )
880    server_savepoint.before_restore_op = server_vars_initializer.name
881
882    # In graph mode, TensorFlow does not allow creating a
883    # `tf.compat.v1.train.Saver` when there are no variables. As a result,
884    # calling `create_state_vars_and_savepoint` below will fail when there are
885    # no SimpleAgg variables (e.g., all results are aggregated via SecAgg). In
886    # this case, there are no client checkpoints, and hence, no need to populate
887    # the `read_update` field.
888    if num_simpleagg_vars > 0:
889      # Run the initializer for update vars prior to restoring the client update
890      update_vars, read_update = (
891          checkpoint_utils.create_state_vars_and_savepoint(
892              simpleagg_update_type, artifact_constants.UPDATE
893          )
894      )
895      update_vars_initializer = tf.compat.v1.variables_initializer(
896          update_vars, 'initialize_update_vars'
897      )
898      if (
899          experimental_client_update_format
900          == checkpoint_type.CheckpointFormatType.APPEND_SLICES_MERGE_READ
901      ):
902        graph = tf.compat.v1.get_default_graph()
903        checkpoint_pl = graph.get_tensor_by_name(
904            read_update.saver_def.filename_tensor_name
905        )
906        merge_checkpoint_slices = append_slices.merge_appended_slices(
907            checkpoint_pl, 'merge_checkpoint_slices'
908        )
909        init_merge = tf.group(update_vars_initializer, merge_checkpoint_slices)
910        read_update.before_restore_op = init_merge.name
911      else:
912        read_update.before_restore_op = update_vars_initializer.name
913    else:
914      # Create an empty list for `update_vars` when there are no SimpleAgg
915      # variables, to be compatible with the `accumulated_values` defined below.
916      update_vars = []
917
918    # Copy the intermediate aggregator's update saver for use on coordinator.
919    read_intermediate_update = plan_pb2.CheckpointOp()
920    read_intermediate_update.CopyFrom(write_intermediate_update)
921
922    # Condition all the remaining logic on the variable initializers, since
923    # intermediate aggregators are supposed to be stateless (no savepoint, and
924    # therefore no `before_restore_op`, either).
925    with tf.control_dependencies(
926        [
927            tf.compat.v1.variables_initializer(
928                (intermediate_update_vars + aggregated_update_vars),
929                'initialize_accumulator_vars',
930            )
931        ]
932    ):
933      # Embeds the `zero` logic and hooks it up to `after_restore_op` of
934      # server's checkpointed state (shared between the coordinator and the
935      # intermediate aggregators). The zeros get assigned to
936      # `intermediate_update_vars` and to the `aggregated_update_vars` at the
937      # very beginning, right after restoring from `server_savepoint`.
938      zero_values = graph_helpers.import_tensorflow('zero', mrf.zero)
939      assign_zero_ops = tf.nest.map_structure(
940          lambda variable, value: variable.assign(value),
941          intermediate_update_vars,
942          zero_values,
943      ) + tf.nest.map_structure(
944          lambda variable, value: variable.assign(value),
945          aggregated_update_vars,
946          zero_values,
947      )
948
949    # Embeds the `prepare` logic, and hooks it up to `before_save_op` of
950    # client state (to be checkpointed and sent to the clients at the
951    # beginning of the round by the central coordinator).
952    with tf.control_dependencies(
953        [
954            tf.compat.v1.variables_initializer(
955                client_vars_on_server, 'initialize_client_vars_on_server'
956            )
957        ]
958    ):
959      # Configure the session token for `write_client` so that the `prepare`
960      # operation may be fed the callback ID for the `SaveSlices` op
961      # (necessary for plans containing `federated_select`).
962      write_client_session_token = tf.compat.v1.placeholder_with_default(
963          input='', shape=(), name='write_client_session_token'
964      )
965      prepared_values = graph_helpers.import_tensorflow(
966          'prepare',
967          mrf.prepare,
968          server_state_vars,
969          session_token_tensor=write_client_session_token,
970      )
971      if is_broadcast_empty:
972        # If the broadcast was empty, don't assign the sample incoming
973        # tf.int32 to anything.
974        client_state_assign_ops = [tf.no_op()]
975      else:
976        client_state_assign_ops = tf.nest.map_structure(
977            lambda variable, tensor: variable.assign(tensor),
978            client_vars_on_server,
979            prepared_values,
980        )
981    write_client.before_save_op = tf.group(*client_state_assign_ops).name
982    write_client.session_token_tensor_name = write_client_session_token.name
983
984    # Embeds the `accumulate` logic, and hooks up the assignment of a client
985    # update to the intermediate update to `aggregate_into_accumulators_op`.
986    accumulated_values = graph_helpers.import_tensorflow(
987        'accumulate', mrf.accumulate, (intermediate_update_vars, update_vars)
988    )
989    intermediate_update_assign_ops = tf.nest.map_structure(
990        lambda variable, tensor: variable.assign(tensor),
991        intermediate_update_vars,
992        accumulated_values,
993    )
994    aggregate_into_accumulators_op = tf.group(
995        *intermediate_update_assign_ops
996    ).name
997
998    secagg_aggregated_update_init = tf.compat.v1.variables_initializer(
999        secagg_aggregated_update_vars
1000    )
1001
1002    # Reset the accumulators in `phase_init_op`, after variable initializers
1003    # and after restoring from the savepoint.
1004    phase_init_op = tf.group(
1005        *(assign_zero_ops + [secagg_aggregated_update_init])
1006    ).name
1007
1008    # Embeds the `merge` logic, and hooks up the assignment of an intermediate
1009    # update to the top-level aggregate update at the coordinator to
1010    # `intermediate_aggregate_into_accumulators_op`.
1011    merged_values = graph_helpers.import_tensorflow(
1012        'merge', mrf.merge, (aggregated_update_vars, intermediate_update_vars)
1013    )
1014    aggregated_update_assign_ops = tf.nest.map_structure(
1015        lambda variable, tensor: variable.assign(tensor),
1016        aggregated_update_vars,
1017        merged_values,
1018    )
1019
1020    secagg_aggregated_update_ops = _merge_secagg_vars(
1021        secure_sum_bitwidth_update_type,
1022        secure_sum_update_type,
1023        secure_modular_sum_update_type,
1024        flattened_moduli,
1025        secagg_aggregated_update_vars,
1026        secagg_intermediate_update_vars,
1027    )
1028
1029    intermediate_aggregate_into_accumulators_op = tf.group(
1030        *(aggregated_update_assign_ops + secagg_aggregated_update_ops)
1031    ).name
1032
1033    # Embeds the `report` and `update` logic, and hooks up the assignments of
1034    # the results of the final update to the server state and metric vars, to
1035    # be triggered by `apply_aggregrated_updates_op`.
1036    simpleagg_reported_values = graph_helpers.import_tensorflow(
1037        'report', mrf.report, aggregated_update_vars
1038    )
1039
1040    # NOTE: In MapReduceForm, the `update` method takes in the simpleagg vars
1041    # and SecAgg vars as a tuple of two separate lists. However, here, as
1042    # above, we concatenate the simpleagg values and the secure values into a
1043    # single list. This mismatch is not a problem because the variables are all
1044    # flattened either way when traveling in and out of the tensorflow graph.
1045    combined_update_vars = (
1046        simpleagg_reported_values + secagg_aggregated_update_vars
1047    )
1048    new_server_state_values, server_metrics_values = (
1049        graph_helpers.import_tensorflow(
1050            artifact_constants.UPDATE,
1051            mrf.update,
1052            (server_state_vars, combined_update_vars),
1053            split_outputs=True,
1054        )
1055    )
1056
1057    assign_server_state_ops = tf.nest.map_structure(
1058        lambda variable, tensor: variable.assign(tensor),
1059        server_state_vars,
1060        new_server_state_values,
1061    )
1062    assign_non_state_ops = tf.nest.map_structure(
1063        lambda variable, value: variable.assign(value),
1064        server_metrics_vars,
1065        server_metrics_values,
1066    )
1067    all_assign_ops = assign_server_state_ops + assign_non_state_ops
1068    apply_aggregrated_updates_op = tf.group(*all_assign_ops).name
1069
1070    # Constructs the metadata for server metrics to be included in the plan.
1071    server_metrics = [
1072        proto_helpers.make_metric(v, artifact_constants.SERVER_STATE_VAR_PREFIX)
1073        for v in server_metrics_vars
1074    ]
1075
1076  server_phase_kwargs = collections.OrderedDict(
1077      phase_init_op=phase_init_op,
1078      write_client_init=write_client,
1079      read_aggregated_update=read_secagg_update,
1080      write_intermediate_update=write_intermediate_update,
1081      read_intermediate_update=read_intermediate_update,
1082      intermediate_aggregate_into_accumulators_op=(
1083          intermediate_aggregate_into_accumulators_op
1084      ),
1085      write_accumulators=write_accumulators,
1086      apply_aggregrated_updates_op=apply_aggregrated_updates_op,
1087      metrics=server_metrics,
1088  )
1089
1090  if num_simpleagg_vars > 0:
1091    # The `read_update` loads SimpleAgg updates from client checkpoints. The
1092    # `aggregate_into_accumulators_op` aggregates SimpleAgg data after loading
1093    # the client updates. No need to populate the two fields if there are no
1094    # SimpleAgg variables (e.g., if all results are aggregated via SecAgg).
1095    server_phase_kwargs['read_update'] = read_update
1096    server_phase_kwargs['aggregate_into_accumulators_op'] = (
1097        aggregate_into_accumulators_op
1098    )
1099
1100  server_phase = plan_pb2.ServerPhase(**server_phase_kwargs)
1101
1102  broadcasted_tensor_specs = tf.nest.map_structure(
1103      variable_helpers.tensorspec_from_var, client_vars_on_server
1104  )
1105  server_graph_def = server_graph.as_graph_def()
1106
1107  if write_metrics_to_checkpoint:
1108    server_graph_def = _redirect_save_saver_to_restore_saver_placeholder(
1109        server_graph_def
1110    )
1111
1112  return (
1113      server_graph_def,
1114      server_savepoint,
1115      server_phase,
1116      broadcasted_tensor_specs,
1117  )
1118
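# Hypothetical usage sketch (illustration only), tying the helpers above
# together for a compiled `MapReduceForm` `mrf` and a broadcast type
# `broadcast_tff_type`:
#
#   bitwidths, max_inputs, moduli = _compute_secagg_parameters(mrf)
#   server_graphdef, server_savepoint, server_phase, broadcast_specs = (
#       _build_server_graph(
#           mrf,
#           broadcast_tff_type,
#           is_broadcast_empty=False,
#           flattened_bitwidths=tf.nest.flatten(bitwidths),
#           flattened_max_inputs=tf.nest.flatten(max_inputs),
#           flattened_moduli=tf.nest.flatten(moduli),
#       )
#   )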
1119
1120def _redirect_save_saver_to_restore_saver_placeholder(
1121    graph_def: tf.compat.v1.GraphDef,
1122) -> tf.compat.v1.GraphDef:
1123  """Updates save Saver's savepoint to point to restore Saver's placeholder.
1124
1125  NOTE: mutates the GraphDef passed in and returns the mutated GraphDef.
1126
1127  When we create the server_savepoint Saver while outputting all of
1128  the metrics to the output checkpoint as well, we set different nodes for
1129  saving and restoring so that we could save state + metrics and restore
1130  just state. However, the only way to do so was to make two Savers and
1131  splice them together. This meant that the save and restore operations
1132  depend on two different placeholders for the checkpoint filename. To
1133  avoid server changes that pass the same checkpoint name in twice to both
1134  placeholders, we make a few changes to the server GraphDef so that the
1135  saving op connects back to the placeholder for the restore operation.
1136  Once this is done, the original save placeholder node will still exist in
1137  the graph, but it won't be used by any part of the graph that connects to
1138  an operation we care about.
1139
1140  Args:
1141    graph_def: A `tf.compat.v1.GraphDef` to mutate.
1142
1143  Returns:
1144    The mutated `tf.compat.v1.GraphDef` that was passed in as graph_def.
1145  """
1146  old_const_node = f'{checkpoint_utils.SAVE_SERVER_SAVEPOINT_NAME}/Const'
1147  new_const_node = (
1148      f'{artifact_constants.SERVER_STATE_VAR_PREFIX}_savepoint/Const'
1149  )
1150  nodes_to_change = [
1151      f'{checkpoint_utils.SAVE_SERVER_SAVEPOINT_NAME}/save',
1152      f'{checkpoint_utils.SAVE_SERVER_SAVEPOINT_NAME}/control_dependency',
1153      f'{checkpoint_utils.SAVE_SERVER_SAVEPOINT_NAME}/RestoreV2',
1154  ]
1155  num_changed_nodes = 0
1156  for node in graph_def.node:
1157    if node.name in nodes_to_change:
1158      input_index = 0
1159      for input_index, input_node in enumerate(node.input):
1160        if input_node == old_const_node:
1161          node.input[input_index] = new_const_node
1162          break
1163      assert input_index != len(
1164          node.input
1165      ), 'Missed input arg in saver GraphDef rewriting.'
1166      num_changed_nodes = num_changed_nodes + 1
1167    if num_changed_nodes == len(nodes_to_change):
1168      # Once we've changed all of the callsites, we stop.
1169      return graph_def
1170  return graph_def
1171
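# Minimal illustration of the input rewrite performed above, using hypothetical
# node names:
#
#   node = tf.compat.v1.NodeDef(
#       name='save_server_savepoint/save', op='SaveV2',
#       input=['old_savepoint/Const', 'tensor_names'])
#   node.input[0] = 'new_savepoint/Const'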
1172
1173def _build_client_graph_with_tensorflow_spec(
1174    client_work_comp: tff.Computation,
1175    dataspec,
1176    broadcasted_tensor_specs: Iterable[tf.TensorSpec],
1177    is_broadcast_empty: bool,
1178    *,
1179    experimental_checkpoint_write: checkpoint_type.CheckpointFormatType = checkpoint_type.CheckpointFormatType.TF1_SAVE_SLICES,
1180) -> tuple[tf.compat.v1.GraphDef, plan_pb2.ClientPhase]:
1181  """Builds the client graph and ClientPhase with TensorflowSpec populated.
1182
1183  This function builds a client phase with tensorflow specs proto.
1184
1185  Args:
1186    client_work_comp: A `tff.Computation` that represents the TensorFlow logic
1187      run on-device.
1188    dataspec: Either an instance of `data_spec.DataSpec` or a nested structure
1189      of these that matches the structure of the first element of the input to
1190      `client_work_comp`.
1191    broadcasted_tensor_specs: A list of `tf.TensorSpec` containing the name and
1192      dtype of the variables arriving via the broadcast checkpoint.
1193    is_broadcast_empty: A boolean indicating whether the MapReduce form
1194      initially called for an empty broadcast. In this case the
1195      broadcasted_tensor_specs will contain a single tf.int32, but it will be
1196      ignored.
1197    experimental_checkpoint_write: Determines the format of the final client
1198      update checkpoint. The value affects required operations and might have
1199      performance implications.
1200
1201  Returns:
1202    A `tuple` of the client TensorFlow GraphDef and the client phase protocol
1203      message.
1204
1205  Raises:
1206    SecureAggregationTensorShapeError: If SecAgg tensors do not have all
1207      dimensions of their shape fully defined.
1208    ValueError: If any of the arguments are found to be in an unexpected form.
1209  """
1210  if (
1211      not isinstance(client_work_comp.type_signature.parameter, tff.StructType)
1212      or len(client_work_comp.type_signature.parameter) < 1
1213  ):
1214    raise ValueError(
1215        'client_work_comp.type_signature.parameter should be a '
1216        '`tff.StructType` with length >= 1, but found: {p}.'.format(
1217            p=client_work_comp.type_signature.parameter
1218        )
1219    )
1220
1221  if (
1222      not isinstance(client_work_comp.type_signature.result, tff.StructType)
1223      or len(client_work_comp.type_signature.result) != 4
1224  ):
1225    raise ValueError(
1226        'client_work_comp.type_signature.result should be a '
1227        '`tff.StructType` with length == 4, but found: {r}.'.format(
1228            r=client_work_comp.type_signature.result
1229        )
1230    )
1231
1232  (
1233      simpleagg_update_type,
1234      secure_sum_bitwidth_update_type,
1235      secure_sum_update_type,
1236      secure_modular_sum_update_type,
1237  ) = client_work_comp.type_signature.result
1238
1239  # A list of tensors that will be passed into TensorFlow, corresponding to
1240  # `plan_pb2.ClientPhase.tensorflow_spec.input_tensor_specs`. Note that the
1241  # dataset token is excluded from this list. In general, this list should
1242  # include the filepath placeholder tensors for the input checkpoint file and
1243  # output checkpoint file.
1244  input_tensors = []
1245
1246  # A list of tensor specs that should be fetched from TensorFlow, corresponding
1247  # to `plan_pb2.ClientPhase.tensorflow_spec.output_tensor_specs`. In general,
1248  # this list should include the tensors that are not in the output checkpoint
1249  # file, such as secure aggregation tensors.
1250  output_tensor_specs = []
1251
1252  # A list of node names in the client graph that should be executed but no
1253  # output returned, corresponding to
1254  # `plan_pb2.ClientPhase.tensorflow_spec.target_node_names`. In general, this
1255  # list should include the op that creates the output checkpoint file.
1256  target_nodes = []
1257  with tf.Graph().as_default() as client_graph:
1258    input_filepath_placeholder = None
1259    if not is_broadcast_empty:
1260      input_filepath_placeholder = tf.compat.v1.placeholder(
1261          name='input_filepath', shape=(), dtype=tf.string
1262      )
1263      weights_from_server = checkpoint_utils.restore_tensors_from_savepoint(
1264          broadcasted_tensor_specs, input_filepath_placeholder
1265      )
1266      input_tensors.append(input_filepath_placeholder)
1267    else:
1268      weights_from_server = []
1269
1270    # Add the custom Dataset ops to the graph.
1271    token_placeholder, data_values, example_selector_placeholders = (
1272        graph_helpers.embed_data_logic(
1273            client_work_comp.type_signature.parameter[0], dataspec
1274        )
1275    )
1276
1277    # Embed the graph coming from TFF into the client work graph.
1278    combined_update_tensors = graph_helpers.import_tensorflow(
1279        'work',
1280        client_work_comp,
1281        (data_values, weights_from_server),
1282        split_outputs=False,
1283        session_token_tensor=token_placeholder,
1284    )  # pytype: disable=wrong-arg-types
1285
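    # The flattened outputs of `work` follow the order of its result struct:
    # simpleagg tensors first, then the secagg tensors, so the combined flat
    # list can be split by count.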
1286    num_simpleagg_tensors = len(tff.structure.flatten(simpleagg_update_type))
1287    simpleagg_tensors = combined_update_tensors[:num_simpleagg_tensors]
1288    secagg_tensors = combined_update_tensors[num_simpleagg_tensors:]
1289
1290    # For tensors aggregated by secagg, we make sure the tensor names are
1291    # aligned in both the client and server graphs by getting the names from
1292    # the same method.
1293    secagg_tensor_names = []
1294    secagg_tensor_types = []
1295    for uri, update_type in [
1296        (SECURE_SUM_BITWIDTH_URI, secure_sum_bitwidth_update_type),
1297        (SECURE_SUM_URI, secure_sum_update_type),
1298        (SECURE_MODULAR_SUM_URI, secure_modular_sum_update_type),
1299    ]:
1300      secagg_tensor_names += variable_helpers.get_shared_secagg_tensor_names(
1301          uri, update_type
1302      )
1303      secagg_tensor_types += tff.structure.flatten(update_type)
1304
1305    secagg_tensors = [
1306        tf.identity(tensor, name=tensor_utils.bare_name(name))
1307        for tensor, name in zip(secagg_tensors, secagg_tensor_names)
1308    ]
1309    for t, type_spec in zip(secagg_tensors, secagg_tensor_types):
1310      secagg_tensor_spec = proto_helpers.make_tensor_spec_from_tensor(
1311          t, shape_hint=type_spec.shape
1312      )
1313      output_tensor_specs.append(secagg_tensor_spec.experimental_as_proto())
1314
1315    # Verify that SecAgg tensors have all dimensions fully defined.
1316    for tensor_spec in output_tensor_specs:
1317      if not tf.TensorShape(tensor_spec.shape).is_fully_defined():
1318        raise SecureAggregationTensorShapeError(
1319            '`TensorflowSpec.output_tensor_specs` has unknown dimension.'
1320        )
1321
1322    output_filepath_placeholder = None
1323    if simpleagg_tensors:
1324      output_filepath_placeholder = tf.compat.v1.placeholder(
1325          dtype=tf.string, shape=(), name='output_filepath'
1326      )
1327      simpleagg_variable_names = variable_helpers.variable_names_from_type(
1328          simpleagg_update_type, name=artifact_constants.UPDATE
1329      )
1330      if experimental_checkpoint_write in [
1331          checkpoint_type.CheckpointFormatType.APPEND_SLICES_MERGE_WRITE,
1332          checkpoint_type.CheckpointFormatType.APPEND_SLICES_MERGE_READ,
1333      ]:
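        # Delete any pre-existing output file, then write each tensor as an
        # appended slice. For APPEND_SLICES_MERGE_WRITE the slices are merged
        # into a single checkpoint below; for APPEND_SLICES_MERGE_READ they are
        # left unmerged for the reader.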
1334        delete_op = delete_file.delete_file(output_filepath_placeholder)
1335        with tf.control_dependencies([delete_op]):
1336          append_ops = []
1337          for tensor_name, tensor in zip(
1338              simpleagg_variable_names, simpleagg_tensors
1339          ):
1340            append_ops.append(
1341                tensor_utils.save(
1342                    filename=output_filepath_placeholder,
1343                    tensor_names=[tensor_name],
1344                    tensors=[tensor],
1345                    save_op=append_slices.append_slices,
1346                )
1347            )
1348        if (
1349            experimental_checkpoint_write
1350            == checkpoint_type.CheckpointFormatType.APPEND_SLICES_MERGE_WRITE
1351        ):
1352          with tf.control_dependencies(append_ops):
1353            save_op = append_slices.merge_appended_slices(
1354                filename=output_filepath_placeholder
1355            )
1356        else:
1357          # APPEND_SLICES_MERGE_READ
1358          save_op = tf.group(*append_ops)
1359
1360      elif (
1361          experimental_checkpoint_write
1362          == checkpoint_type.CheckpointFormatType.TF1_SAVE_SLICES
1363      ):
1364        save_op = tensor_utils.save(
1365            filename=output_filepath_placeholder,
1366            tensor_names=simpleagg_variable_names,
1367            tensors=simpleagg_tensors,
1368            name='save_client_update_tensors',
1369        )
1370      else:
1371        raise NotImplementedError(
1372            f'Unsupported CheckpointFormatType {experimental_checkpoint_write}.'
1373        )
1374      input_tensors.append(output_filepath_placeholder)
1375      target_nodes.append(save_op.name)
1376
1377  tensorflow_spec = plan_pb2.TensorflowSpec()
1378  if token_placeholder is not None:
1379    tensorflow_spec.dataset_token_tensor_name = token_placeholder.name
1380  if input_tensors:
1381    tensorflow_spec.input_tensor_specs.extend(
1382        tf.TensorSpec.from_tensor(t, name=t.name).experimental_as_proto()
1383        for t in input_tensors
1384    )
1385  if output_tensor_specs:
1386    tensorflow_spec.output_tensor_specs.extend(output_tensor_specs)
1387  if target_nodes:
1388    tensorflow_spec.target_node_names.extend(target_nodes)
1389  if example_selector_placeholders:
1390    for placeholder in example_selector_placeholders:
1391      # Generating the default TensorProto will create a TensorProto with a
1392      # DT_INVALID DType. This identifies that there is a placeholder that is
1393      # needed. In order to have the Plan proto be completely runnable, the
1394      # value will need to be filled in with a real TensorProto that matches
1395      # the shape/type of the expected input.
1396      tensorflow_spec.constant_inputs[placeholder.name].dtype = 0
1397
1398  io_router = plan_pb2.FederatedComputeIORouter()
1399  if input_filepath_placeholder is not None:
1400    io_router.input_filepath_tensor_name = input_filepath_placeholder.name
1401  if output_filepath_placeholder is not None:
1402    io_router.output_filepath_tensor_name = output_filepath_placeholder.name
1403  for secagg_tensor in secagg_tensors:
1404    io_router.aggregations[secagg_tensor.name].CopyFrom(
1405        plan_pb2.AggregationConfig(
1406            secure_aggregation=plan_pb2.SecureAggregationConfig()
1407        )
1408    )
1409
1410  return client_graph.as_graph_def(), plan_pb2.ClientPhase(
1411      tensorflow_spec=tensorflow_spec, federated_compute=io_router
1412  )
1413
1414
1415def _build_client_phase_with_example_query_spec(
1416    client_work_comp: tff.Computation,
1417    example_query_spec: plan_pb2.ExampleQuerySpec,
1418) -> plan_pb2.ClientPhase:
1419  """Builds the ClientPhase with `ExampleQuerySpec` populated.
1420
1421  Args:
1422    client_work_comp: A `tff.Computation` that represents the TensorFlow logic
1423      run on-device.
1424    example_query_spec: Field containing output vector information for client
1425      example query. The output vector names listed in the spec are expected to
1426      be consistent with the output names we would produce in the
1427      `MapReduceForm` client work computation, if we were to build a TF-based
1428      plan from that `MapReduceForm`.
1429
1430  Returns:
1431    A `plan_pb2.ClientPhase` for the example query based client plan.
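
  Example:
    A minimal sketch mirroring how `build_plan` calls this helper, assuming an
    existing `MapReduceForm` `mrf` and a pre-built `plan_pb2.ExampleQuerySpec`
    named `example_query_spec` (both hypothetical):

      client_phase = _build_client_phase_with_example_query_spec(
          mrf.work, example_query_spec
      )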
1432  """
1433  expected_vector_names = set(
1434      variable_helpers.variable_names_from_type(
1435          client_work_comp.type_signature.result[0], artifact_constants.UPDATE
1436      )
1437  )
1438  used_names = set()
1439  io_router = plan_pb2.FederatedExampleQueryIORouter()
1440  for example_query in example_query_spec.example_queries:
1441    vector_names = set(example_query.output_vector_specs.keys())
1442    if not all([name in expected_vector_names for name in vector_names]):
1443      raise ValueError(
1444          'Found unexpected vector names in supplied `example_query_spec`. '
1445          f'Expected names: {expected_vector_names}. '
1446          f'Found unexpected names: {vector_names - expected_vector_names}.'
1447      )
1448
1449    if any([name in used_names for name in vector_names]):
1450      raise ValueError(
1451          'Duplicate vector names found in supplied `example_query_spec`. '
1452          f'Duplicates: {vector_names.intersection(used_names)}'
1453      )
1454
1455    used_names.update(vector_names)
1456
1457    for vector_name in vector_names:
1458      io_router.aggregations[vector_name].CopyFrom(
1459          plan_pb2.AggregationConfig(
1460              tf_v1_checkpoint_aggregation=plan_pb2.TFV1CheckpointAggregation()
1461          )
1462      )
1463
1464  if used_names != expected_vector_names:
1465    raise ValueError(
1466        'Not all expected vector names were in supplied `example_query_spec`.'
1467        f' Expected names: {expected_vector_names}. Names not present in'
1468        f' `example_query_spec`: {expected_vector_names - used_names}'
1469    )
1470  return plan_pb2.ClientPhase(
1471      example_query_spec=example_query_spec, federated_example_query=io_router
1472  )
1473
1474
1475def build_plan(
1476    mrf: tff.backends.mapreduce.MapReduceForm,
1477    daf: Optional[tff.backends.mapreduce.DistributeAggregateForm] = None,
1478    dataspec: Optional[data_spec.NestedDataSpec] = None,
1479    example_query_spec: Optional[plan_pb2.ExampleQuerySpec] = None,
1480    grappler_config: Optional[tf.compat.v1.ConfigProto] = None,
1481    additional_checkpoint_metadata_var_fn: Optional[
1482        Callable[[tff.StructType, tff.StructType, bool], list[tf.Variable]]
1483    ] = None,
1484    experimental_client_checkpoint_write: checkpoint_type.CheckpointFormatType = checkpoint_type.CheckpointFormatType.TF1_SAVE_SLICES,
1485    generate_server_phase_v2: bool = False,
1486    write_metrics_to_checkpoint: bool = True,
1487) -> plan_pb2.Plan:
1488  """Constructs an instance of `plan_pb2.Plan` given a `MapReduceForm` instance.
1489
1490  Plans generated by this method are executable, but a number of features have
1491  yet to be implemented.
1492
1493  These include:
1494
1495  - Setting metrics' `stat_name` field based on externally-supplied metadata,
1496    such as that from the model stampers. Currently, these names are based on
1497    the names of TensorFlow variables, which in turn are based on the TFF
1498    type signatures.
1499
1500  - Populating the client `example_selector` field. Currently not set.
1501
1502  - Populating client-side `savepoint`. Currently not set.
1503
1504  - Populating the plan's `tensorflow_config_proto`. Currently not set.
1505
1506  - Setting a field in the plan that represents a token to drive the custom op
1507    that implements the client-side dataset. There is no such field in the plan
1508    at the time of this writing.
1509
1510  - Populating plan fields related to secure aggregation and side channels,
1511    such as the `read_aggregated_update` checkpoint op.
1512
1513  Args:
1514    mrf: An instance of `tff.backends.mapreduce.MapReduceForm`.
1515    daf: An instance of `tff.backends.mapreduce.DistributeAggregateForm`.
1516    dataspec: If provided, either an instance of `data_spec.DataSpec` or a
1517      nested structure of these that matches the structure of the first element
1518      of the input to client-side processing computation `mrf.work`. If not
1519      provided and `example_query_spec` is also not provided, then placeholders
1520      are added to the client graph via `embed_data_logic()` and the example
1521      selectors will need to be passed to the client via the `constant_inputs`
1522      part of the `TensorflowSpec`. The constant_inputs field needs to be
1523      populated outside of `build_plan()`. Can only provide one of `dataspec` or
1524      `example_query_spec`.
1525    example_query_spec: An instance of `plan_pb2.ExampleQuerySpec`. If provided,
1526      it is assumed a lightweight client plan should be constructed. No client
1527      graph will be included in the produced plan object. Instead the generated
1528      plan will have an `ExampleQuerySpec` and `FederatedExampleQueryIORouter`.
1529      Can only supply one of `dataspec` or `example_query_spec`.
1530    grappler_config: The config specifying Grappler optimizations for TFF-
1531      generated graphs. Should be provided if `daf` is provided.
1532    additional_checkpoint_metadata_var_fn: An optional method that takes in a
1533      server state type, a server metrics type, and a boolean determining
1534      whether to revert to legacy metrics behavior to produce additional
1535      metadata variables.
1536    experimental_client_checkpoint_write: Determines how the client checkpoint
1537      (client->server communication) is written. The value affects the
1538      operations used and might have an impact on overall task performance.
1539    generate_server_phase_v2: Iff `True`, will produce a ServerPhaseV2 message
1540      in the plan in addition to a ServerPhase message.
1541    write_metrics_to_checkpoint: If False, revert to legacy behavior where
1542      metric values were handled by post-processing separate from the output
1543      checkpoint. Regardless, they will additionally continue to be written to
1544      recordio and accumulator checkpoints as defined by the Plan proto.
1545
1546  Returns:
1547    An instance of `plan_pb2.Plan` corresponding to MapReduce form `mrf`.
1548
1549  Raises:
1550    TypeError: If the arguments are of the wrong types.
1551    ValueError: If any of the arguments are found to be in an unexpected form.
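
  Example:
    A minimal sketch of typical usage, assuming an existing
    `tff.backends.mapreduce.MapReduceForm` `mrf` and a nested
    `data_spec.DataSpec` structure `my_dataspec` (both hypothetical):

      plan = build_plan(mrf=mrf, dataspec=my_dataspec)
      serialized_plan = plan.SerializeToString()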
1552  """
1553  type_checks.check_type(mrf, tff.backends.mapreduce.MapReduceForm, name='mrf')
1554  client_plan_type = (
1555      ClientPlanType.TENSORFLOW
1556      if example_query_spec is None
1557      else ClientPlanType.EXAMPLE_QUERY
1558  )
1559
1560  if example_query_spec is not None:
1561    if dataspec is not None:
1562      raise ValueError(
1563          '`example_query_spec` and `dataspec` cannot both be specified.'
1564      )
1565
1566  with tff.framework.get_context_stack().install(
1567      tff.test.create_runtime_error_context()
1568  ):
1569    is_broadcast_empty = (
1570        isinstance(mrf.prepare.type_signature.result, tff.StructType)
1571        and not mrf.prepare.type_signature.result
1572    )
1573    if is_broadcast_empty:
1574      # This MapReduceForm does not send any server state to clients; however,
1575      # we need something to satisfy current restrictions from the FCP server.
1576      # Use a placeholder scalar int.
1577      broadcast_tff_type = tff.TensorType(tf.int32)
1578    else:
1579      broadcast_tff_type = mrf.prepare.type_signature.result
1580
1581    # Compute the SecAgg parameters (bitwidths, max inputs, and moduli).
1582    bitwidths, max_inputs, moduli = _compute_secagg_parameters(mrf)
1583    # Note: the variables below are flat lists, even though
1584    # `secure_sum_bitwidth_update_type`
1585    # could potentially represent a large group of nested tensors. In order
1586    # for each var to line up with the appropriate bitwidth, we must also
1587    # flatten the list of bitwidths.
1588    flattened_bitwidths = tff.structure.flatten(bitwidths)
1589    flattened_max_inputs = tff.structure.flatten(max_inputs)
1590    flattened_moduli = tff.structure.flatten(moduli)
1591
1592    (
1593        server_graph_def,
1594        server_savepoint,
1595        server_phase,
1596        broadcasted_tensor_specs,
1597    ) = _build_server_graph(
1598        mrf,
1599        broadcast_tff_type,
1600        is_broadcast_empty,
1601        flattened_bitwidths,
1602        flattened_max_inputs,
1603        flattened_moduli,
1604        write_metrics_to_checkpoint,
1605        additional_checkpoint_metadata_var_fn,
1606        experimental_client_update_format=experimental_client_checkpoint_write,
1607    )
1608
1609    if client_plan_type == ClientPlanType.TENSORFLOW:
1610      client_graph_def, client_phase = _build_client_graph_with_tensorflow_spec(
1611          mrf.work,
1612          dataspec,
1613          broadcasted_tensor_specs,
1614          is_broadcast_empty,
1615          experimental_checkpoint_write=experimental_client_checkpoint_write,
1616      )
1617    elif client_plan_type == ClientPlanType.EXAMPLE_QUERY:
1618      client_phase = _build_client_phase_with_example_query_spec(
1619          mrf.work, example_query_spec
1620      )
1621    else:
1622      raise ValueError(
1623          f'Unexpected value for `client_plan_type`: {client_plan_type}'
1624      )
1625
1626    combined_phases = plan_pb2.Plan.Phase(
1627        server_phase=server_phase, client_phase=client_phase
1628    )
1629
1630    if generate_server_phase_v2:
1631      assert daf
1632      assert grappler_config
1633      (server_graph_def_prepare, server_graph_def_result, server_phase_v2) = (
1634          _build_server_graphs_from_distribute_aggregate_form(
1635              daf, is_broadcast_empty, grappler_config
1636          )
1637      )
1638      combined_phases.server_phase_v2.CopyFrom(server_phase_v2)
1639
1640    plan = plan_pb2.Plan(
1641        version=1, server_savepoint=server_savepoint, phase=[combined_phases]
1642    )
1643
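    # Pack the GraphDef protos into the plan's `google.protobuf.Any` fields.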
1644    plan.server_graph_bytes.Pack(server_graph_def)
1645    if client_plan_type == ClientPlanType.TENSORFLOW:
1646      plan.client_graph_bytes.Pack(client_graph_def)
1647
1648    if generate_server_phase_v2:
1649      plan.server_graph_prepare_bytes.Pack(server_graph_def_prepare)
1650      plan.server_graph_result_bytes.Pack(server_graph_def_result)
1651  return plan
1652
1653
1654def build_cross_round_aggregation_execution(
1655    mrf: tff.backends.mapreduce.MapReduceForm,
1656) -> bytes:
1657  """Constructs an instance of `plan_pb2.CrossRoundAggregationExecution`.
1658
1659  Args:
1660    mrf: An instance of `tff.backends.mapreduce.MapReduceForm`.
1661
1662  Returns:
1663    A serialized instance of `plan_pb2.CrossRoundAggregationExecution` for the
1664    given `mrf`.
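
  Example:
    A minimal sketch, assuming an existing `MapReduceForm` `mrf`
    (hypothetical):

      serialized = build_cross_round_aggregation_execution(mrf)
      execution = plan_pb2.CrossRoundAggregationExecution()
      execution.ParseFromString(serialized)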
1665  """
1666  type_checks.check_type(mrf, tff.backends.mapreduce.MapReduceForm, name='mrf')
1667
1668  server_metrics_type = mrf.update.type_signature.result[1]
1669  (
1670      simpleagg_update_type,
1671      secure_sum_bitwidth_update_type,
1672      secure_sum_update_type,
1673      secure_modular_sum_update_type,
1674  ) = mrf.work.type_signature.result
1675  # We never work directly on `simpleagg_update_type` because client updates
1676  # are transformed by `accumulate` and `merge` before being passed into
1677  # cross-round aggregation.
1678  del simpleagg_update_type
1679  simpleagg_merge_type = mrf.merge.type_signature.result
1680  flattened_moduli = tff.structure.flatten(mrf.secure_modular_sum_modulus())
1681
1682  if not server_metrics_type:
1683    # No metrics to aggregate; will initialize to no-op.
1684    server_metrics_type = tff.StructType([])
1685  elif isinstance(server_metrics_type, tff.TensorType):
1686    # Single tensor metric; must be wrapped inside a `tff.StructType` for
1687    # proper variable initialization.
1688    server_metrics_type = tff.StructType([server_metrics_type])
1689  combined_aggregated_update_type = tff.StructType([
1690      simpleagg_merge_type,
1691      secure_sum_bitwidth_update_type,
1692      secure_sum_update_type,
1693      secure_modular_sum_update_type,
1694  ])
1695
1696  with tf.Graph().as_default() as cross_round_aggregation_graph:
1697    server_state_vars = variable_helpers.create_vars_for_tff_type(
1698        mrf.update.type_signature.parameter[0],
1699        artifact_constants.SERVER_STATE_VAR_PREFIX,
1700    )
1701
1702    combined_aggregated_update_vars, read_aggregated_update = (
1703        checkpoint_utils.create_state_vars_and_savepoint(
1704            combined_aggregated_update_type, 'aggregated_update'
1705        )
1706    )
1707
1708    num_simpleagg_vars = len(tff.structure.flatten(simpleagg_merge_type))
1709
1710    aggregated_update_vars = combined_aggregated_update_vars[
1711        :num_simpleagg_vars
1712    ]
1713    secagg_aggregated_update_vars = combined_aggregated_update_vars[
1714        num_simpleagg_vars:
1715    ]
1716
1717    # Add a new output for metrics_loader `merge` and `report`.
1718    combined_final_accumulator_vars, read_write_final_accumulators = (
1719        checkpoint_utils.create_state_vars_and_savepoint(
1720            combined_aggregated_update_type, 'final_accumulators'
1721        )
1722    )
1723
1724    final_accumulator_vars = combined_final_accumulator_vars[
1725        :num_simpleagg_vars
1726    ]
1727    secagg_final_accumulator_vars = combined_final_accumulator_vars[
1728        num_simpleagg_vars:
1729    ]
1730
1731    var_init_op = tf.compat.v1.initializers.variables(
1732        server_state_vars
1733        + combined_aggregated_update_vars
1734        + combined_final_accumulator_vars
1735    )
1736
1737    # Embeds the MapReduce form `merge` logic.
1738    merged_values = graph_helpers.import_tensorflow(
1739        'merge', mrf.merge, (final_accumulator_vars, aggregated_update_vars)
1740    )
1741    final_accumulator_assign_ops = tf.nest.map_structure(
1742        lambda variable, tensor: variable.assign(tensor),
1743        final_accumulator_vars,
1744        merged_values,
1745    )
1746
1747    # SecAgg tensors' aggregation is not provided in the imported TensorFlow
1748    # graph, but is instead fixed based on the operator (e.g. `assign_add` for
1749    # variables passed into `secure_sum`).
1750    secagg_final_accumulator_ops = _merge_secagg_vars(
1751        secure_sum_bitwidth_update_type,
1752        secure_sum_update_type,
1753        secure_modular_sum_update_type,
1754        flattened_moduli,
1755        secagg_final_accumulator_vars,
1756        secagg_aggregated_update_vars,
1757    )
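    # Group all assignment ops into a single op; its name is recorded as
    # `merge_op` in the CrossRoundAggregationExecution proto below.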
1758    final_accumulator_op = tf.group(
1759        *(final_accumulator_assign_ops + secagg_final_accumulator_ops)
1760    ).name
1761
1762    # Embeds the `report` and `update` logic, and hooks up the assignments of
1763    # the results of the final update to the server state and metric vars, to
1764    # be triggered by `apply_aggregrated_updates_op`.
1765    simpleagg_reported_values = graph_helpers.import_tensorflow(
1766        'report', mrf.report, final_accumulator_vars
1767    )
1768    combined_final_vars = (
1769        simpleagg_reported_values + secagg_final_accumulator_vars
1770    )
1771    (_, server_metric_values) = graph_helpers.import_tensorflow(
1772        artifact_constants.UPDATE,
1773        mrf.update,
1774        (server_state_vars, combined_final_vars),
1775        split_outputs=True,
1776    )
1777
1778    server_metrics_names = variable_helpers.variable_names_from_type(
1779        server_metrics_type, name=artifact_constants.SERVER_STATE_VAR_PREFIX
1780    )
1781
1782    flattened_metrics_types = tff.structure.flatten(server_metrics_type)
1783    measurements = [
1784        proto_helpers.make_measurement(v, s, a)
1785        for v, s, a in zip(
1786            server_metric_values, server_metrics_names, flattened_metrics_types
1787        )
1788    ]
1789
1790  cross_round_aggregation_execution = plan_pb2.CrossRoundAggregationExecution(
1791      init_op=var_init_op.name,
1792      read_aggregated_update=read_aggregated_update,
1793      merge_op=final_accumulator_op,
1794      read_write_final_accumulators=read_write_final_accumulators,
1795      measurements=measurements,
1796  )
1797
1798  cross_round_aggregation_execution.cross_round_aggregation_graph_bytes.Pack(
1799      cross_round_aggregation_graph.as_graph_def()
1800  )
1801
1802  return cross_round_aggregation_execution.SerializeToString()
1803