# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A library responsible for building Federated Compute plans.

This library builds TFF-backed plans, using the `MapReduceForm` object
output by the TFF compiler pipeline.
"""

import collections
from collections.abc import Callable, Iterable, Mapping, Sequence
import enum
from typing import Optional, TypeVar, Union

import attr
import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import artifact_constants
from fcp.artifact_building import checkpoint_type
from fcp.artifact_building import checkpoint_utils
from fcp.artifact_building import data_spec
from fcp.artifact_building import graph_helpers
from fcp.artifact_building import proto_helpers
from fcp.artifact_building import tensor_utils
from fcp.artifact_building import type_checks
from fcp.artifact_building import variable_helpers
from fcp.protos import plan_pb2
from fcp.tensorflow import append_slices
from fcp.tensorflow import delete_file

SECURE_SUM_BITWIDTH_URI = 'federated_secure_sum_bitwidth'
SECURE_SUM_URI = 'federated_secure_sum'
SECURE_MODULAR_SUM_URI = 'federated_secure_modular_sum'


class SecureAggregationTensorShapeError(Exception):
  """Error raised when secagg tensors do not have fully defined shape."""


@enum.unique
class ClientPlanType(enum.Enum):
  """Option adjusting client plan type during plan building.

  Values:
    TENSORFLOW: The default value. Uses a TF client graph for client
      computation.
    EXAMPLE_QUERY: Uses an example query containing client computation logic in
      the provided example selector(s).
  """

  TENSORFLOW = 'tensorflow'
  EXAMPLE_QUERY = 'example_query'


# A type representing a potentially nested struct of integers.
IntStruct = Union[
    int,
    Mapping[str, Union['IntStruct', int]],
    Sequence[Union['IntStruct', int]],
]


def _compute_secagg_parameters(
    mrf: tff.backends.mapreduce.MapReduceForm,
) -> tuple[IntStruct, IntStruct, IntStruct]:
  """Executes the TensorFlow logic that computes the SecAgg parameters.

  This function makes use of `mrf.secure_sum_bitwidth`,
  `mrf.secure_sum_max_input`, and `mrf.secure_modular_sum_modulus` to derive
  the parameters needed for the SecAgg protocol variants.

  Args:
    mrf: An instance of `tff.backends.mapreduce.MapReduceForm`.

  Returns:
    A 3-tuple of `bitwidth`, `max_input` and `moduli` structures of parameters
    for the associated SecAgg variants.
  """
  type_checks.check_type(mrf, tff.backends.mapreduce.MapReduceForm, name='mrf')
  secagg_parameters = []
  with tf.Graph().as_default() as g:
    for name, computation in [
        ('bitwidth', mrf.secure_sum_bitwidth),
        ('max_input', mrf.secure_sum_max_input),
        ('modulus', mrf.secure_modular_sum_modulus),
    ]:
      secagg_parameters.append(
          graph_helpers.import_tensorflow(name, computation)
      )
  with tf.compat.v1.Session(graph=g) as sess:
    flat_output = sess.run(fetches=tf.nest.flatten(secagg_parameters))
    return tf.nest.pack_sequence_as(secagg_parameters, flat_output)


# A side-channel through which one tensor is securely aggregated.
@attr.s
class SecAggSideChannel:
  # The name of the tensor being aggregated in the client and server graphs.
  tensor_name: str = attr.ib()
  # A proto describing how the side-channel is to be aggregated.
  side_channel_proto: plan_pb2.SideChannel = attr.ib()
  # A placeholder tensor into which the sidechannel aggregation is filled.
  placeholder: tf.Tensor = attr.ib()
  # The variable to feed into the server graph.
  update_var: tf.Variable = attr.ib()


SecAggParam = TypeVar('SecAggParam')


def _create_secagg_sidechannels(
    intrinsic_name: str,
    update_type: variable_helpers.AllowedTffTypes,
    get_modulus_scheme: Callable[
        [SecAggParam], plan_pb2.SideChannel.SecureAggregand
    ],
    params: list[SecAggParam],
) -> list[SecAggSideChannel]:
  """Returns `SecAggSideChannel`s for tensors aggregated with `intrinsic_name`.

  This method also creates variables for the securely-aggregated tensors within
  the current default graph using `create_vars_for_tff_type`.

  Args:
    intrinsic_name: The name of the intrinsic (e.g.
      `federated_secure_sum_bitwidth`) with which the tensors in `update_type`
      are being aggregated.
    update_type: The TFF type representing a structure of all tensors being
      aggregated with `intrinsic_name`.
    get_modulus_scheme: A function which will get the modulus scheme being
      used. This typically requires some additional per-tensor parameters which
      must be supplied using `params`.
    params: A list of arguments to pass to `get_modulus_scheme`. There must be
      exactly one element in this list per tensor in `update_type`.

  Returns:
    A list of `SecAggSideChannel`s describing how to aggregate each tensor.
  """
  # For secure aggregation, we don't use a saver (but still store metadata in a
  # CheckpointOp). Instead we create sidechannel tensors that get fed into the
  # server graph.
  update_vars = variable_helpers.create_vars_for_tff_type(
      update_type, f'{intrinsic_name}_update'
  )

  # For tensors aggregated by secagg, we make sure the tensor names are aligned
  # in both client and server graph by getting the names from the same method.
  tensor_names = variable_helpers.get_shared_secagg_tensor_names(
      intrinsic_name, update_type
  )
  assert len(update_vars) == len(params) == len(tensor_names), (
      'The length of update_vars, params and tensor_names for'
      f' {intrinsic_name} should all be equal, but found: {len(update_vars)},'
      f' {len(params)}, and {len(tensor_names)}.'
  )

  results = []
  for param, update_var, tensor_name in zip(params, update_vars, tensor_names):
    secure_aggregand = get_modulus_scheme(param)
    secure_aggregand.dimension.extend(
        plan_pb2.SideChannel.SecureAggregand.Dimension(size=d.value)
        for d in update_var.shape.dims
    )
    secure_aggregand.dtype = update_var.dtype.base_dtype.as_datatype_enum
    placeholder = tf.compat.v1.placeholder(
        update_var.dtype, update_var.get_shape()
    )
    side_channel_proto = plan_pb2.SideChannel(
        secure_aggregand=secure_aggregand, restore_name=placeholder.name
    )
    results.append(
        SecAggSideChannel(
            tensor_name=tensor_name,
            side_channel_proto=side_channel_proto,
            placeholder=placeholder,
            update_var=update_var,
        )
    )
  return results


def _read_secagg_update_from_sidechannel_into_vars(
    *,  # Require parameters to be named.
    secagg_intermediate_update_vars: list[tf.Variable],
    secure_sum_bitwidth_update_type: (variable_helpers.AllowedTffTypes),
    bitwidths: list[int],
    secure_sum_update_type: (variable_helpers.AllowedTffTypes),
    max_inputs: list[int],
    secure_modular_sum_update_type: (variable_helpers.AllowedTffTypes),
    moduli: list[int],
) -> plan_pb2.CheckpointOp:
  """Creates the `read_secagg_update` op.

  `read_secagg_update` is a `plan_pb2.CheckpointOp` and is used to restore the
  secagg tensors in the server graph.

  Args:
    secagg_intermediate_update_vars: A list of variables to assign the
      secagg_update_data in the `after_restore_op`.
    secure_sum_bitwidth_update_type: The type of the tensors aggregated using
      `bitwidth`-based secure sum.
    bitwidths: The `bitwidth`s for the tensors that will be aggregated using
      `bitwidth`-based secure summation.
    secure_sum_update_type: The type of the tensors aggregated using
      `max_input`-based secure sum.
    max_inputs: The `max_input`s for the tensors that will be aggregated using
      `max_input`-based secure summation.
    secure_modular_sum_update_type: The type of the tensors aggregated using
      modular secure summation.
    moduli: The `modulus`s for the tensors that will be aggregated using modular
      secure summation.

  Returns:
    A `plan_pb2.CheckpointOp` which performs the `read_secagg_update`.
  """
  side_channels: list[SecAggSideChannel] = []

  def _aggregand_for_bitwidth(bitwidth):
    return plan_pb2.SideChannel.SecureAggregand(
        quantized_input_bitwidth=bitwidth
    )

  side_channels += _create_secagg_sidechannels(
      SECURE_SUM_BITWIDTH_URI,
      secure_sum_bitwidth_update_type,
      _aggregand_for_bitwidth,
      bitwidths,
  )

  def _aggregand_for_max_input(max_input):
    # Note the +1-- `max_input` is inclusive, so `base_modulus == max_input`
    # would overflow maximum-valued inputs to zero.
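    # For example (hypothetical value): an inclusive max_input of 7 yields
    # base_modulus = 8, so a maximum-valued input of 7 survives the modular
    # arithmetic instead of wrapping around to zero.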
    base_modulus = max_input + 1
    modulus_times_shard_size = (
        plan_pb2.SideChannel.SecureAggregand.ModulusTimesShardSize(
            base_modulus=base_modulus
        )
    )
    return plan_pb2.SideChannel.SecureAggregand(
        modulus_times_shard_size=modulus_times_shard_size
    )

  side_channels += _create_secagg_sidechannels(
      SECURE_SUM_URI,
      secure_sum_update_type,
      _aggregand_for_max_input,
      max_inputs,
  )

  def _aggregand_for_modulus(modulus):
    fixed_modulus = plan_pb2.SideChannel.SecureAggregand.FixedModulus(
        modulus=modulus
    )
    return plan_pb2.SideChannel.SecureAggregand(fixed_modulus=fixed_modulus)

  side_channels += _create_secagg_sidechannels(
      SECURE_MODULAR_SUM_URI,
      secure_modular_sum_update_type,
      _aggregand_for_modulus,
      moduli,
  )

  # Operations assigning from sidechannel placeholders to update variables.
  assign_placeholders_to_updates = []
  # Operations assigning from update variables to the result variables.
  assign_updates_to_intermediate = []
  read_secagg_update = plan_pb2.CheckpointOp()
  for intermediate_update_var, side_channel in zip(
      secagg_intermediate_update_vars, side_channels
  ):
    assign_placeholders_to_updates.append(
        side_channel.update_var.assign(side_channel.placeholder)
    )
    assign_updates_to_intermediate.append(
        intermediate_update_var.assign(side_channel.update_var)
    )
    read_secagg_update.side_channel_tensors[side_channel.tensor_name].CopyFrom(
        side_channel.side_channel_proto
    )

  read_secagg_update.before_restore_op = tf.group(
      *(assign_placeholders_to_updates)
  ).name
  read_secagg_update.after_restore_op = tf.group(
      *(assign_updates_to_intermediate)
  ).name

  return read_secagg_update


def _merge_secagg_vars(
    secure_sum_bitwidth_update_type: tff.Type,
    secure_sum_update_type: tff.Type,
    secure_modular_sum_update_type: tff.Type,
    flattened_moduli: list[int],
    variables: list[tf.Variable],
    tensors: list[tf.Variable],
) -> list[tf.Operation]:
  """Generates a set of ops to `merge` secagg `tensors` into `variables`."""
  if len(variables) != len(tensors):
    raise ValueError(
        'Expected an equal number of variables and tensors, but found '
        f'{len(variables)} variables and {len(tensors)} tensors.'
    )
  num_simple_add_vars = len(
      tff.structure.flatten(
          tff.to_type([
              secure_sum_bitwidth_update_type,
              secure_sum_update_type,
          ])
      )
  )
  num_modular_add_vars = len(
      tff.structure.flatten(secure_modular_sum_update_type)
  )
  # There must be one variable and tensor for each tensor in the secure update
  # types.
  num_vars_from_types = num_simple_add_vars + num_modular_add_vars
  if num_vars_from_types != len(variables):
    raise ValueError(
        'Expected one variable for each leaf element of the secagg update, but '
        f'found {len(variables)} variables and {num_vars_from_types} leaf '
        'elements in the following types:\n'
        f'secure_sum_bitwidth_update_type: {secure_sum_bitwidth_update_type}\n'
        f'secure_sum_update_type: {secure_sum_update_type}\n'
        f'secure_modular_sum_update_type: {secure_modular_sum_update_type}\n'
    )
  if num_modular_add_vars != len(flattened_moduli):
    raise ValueError(
        'Expected one modulus for each leaf element of the secure modular sum '
        f'update type. Found {len(flattened_moduli)} moduli and '
        f'{num_modular_add_vars} leaf elements in the secure modular sum '
        f'update type:\n{secure_modular_sum_update_type}'
    )
  # Add `tensors` to `variables`, using simple addition for the first
  # `num_simple_add_vars` variables and modular addition for the rest
  # (those coming from `secure_modular_sum`).
  ops = []
  simple_add_vars = variables[:num_simple_add_vars]
  simple_add_tensors = tensors[:num_simple_add_vars]
  for variable, tensor in zip(simple_add_vars, simple_add_tensors):
    ops.append(variable.assign_add(tensor))
  modular_add_vars = variables[num_simple_add_vars:]
  modular_add_tensors = tensors[num_simple_add_vars:]
  for modulus, (variable, tensor) in zip(
      flattened_moduli, zip(modular_add_vars, modular_add_tensors)
  ):
    new_sum = tf.math.add(variable.read_value(), tensor)
    modular_sum = tf.math.floormod(new_sum, modulus)
    ops.append(variable.assign(tf.reshape(modular_sum, tf.shape(variable))))
  return ops


def _build_server_graphs_from_distribute_aggregate_form(
    daf: tff.backends.mapreduce.DistributeAggregateForm,
    is_broadcast_empty: bool,
    grappler_config: tf.compat.v1.ConfigProto,
) -> tuple[
    tf.compat.v1.GraphDef, tf.compat.v1.GraphDef, plan_pb2.ServerPhaseV2
]:
  """Generates the server plan components based on DistributeAggregateForm.

  Derives the pre-broadcast, aggregation, and post-aggregation logical
  components in the ServerPhaseV2 message that will be executed on the server.
  The pre-broadcast and post-aggregation components are to be executed with a
  single TF sess.run call using the corresponding GraphDef. The aggregation
  component is to be executed natively (i.e. not using TensorFlow) according to
  the aggregation messages contained in the ServerPhaseV2 message.

  Args:
    daf: An instance of `tff.backends.mapreduce.DistributeAggregateForm`.
    is_broadcast_empty: A boolean indicating whether the broadcasted value from
      the server is expected to be empty based on the DistributeAggregateForm,
      in which case the server should broadcast a placeholder tf.int32 tensor,
      as empty checkpoints are not supported.
    grappler_config: The config specifying Grappler optimizations for
      TFF-generated graphs.

  Returns:
    A tuple containing the server_prepare GraphDef, the server_result GraphDef,
    and the ServerPhaseV2 message.
  """
  # Generate the TensorFlow graph needed to execute the server_prepare step,
  # including reading input checkpoints and writing output checkpoints.
  server_prepare_input_tensors = []
  server_prepare_target_nodes = []
  with tf.Graph().as_default() as server_prepare_graph:
    # Create the placeholders for the input and output filenames needed by
    # the server_prepare step.
    server_prepare_server_state_input_filepath_placeholder = (
        tf.compat.v1.placeholder(
            name='server_state_input_filepath', shape=(), dtype=tf.string
        )
    )
    server_prepare_output_filepath_placeholder = tf.compat.v1.placeholder(
        name='server_prepare_output_filepath', shape=(), dtype=tf.string
    )
    server_prepare_intermediate_state_output_filepath_placeholder = (
        tf.compat.v1.placeholder(
            name='server_intermediate_state_output_filepath',
            shape=(),
            dtype=tf.string,
        )
    )
    server_prepare_input_tensors.extend([
        server_prepare_server_state_input_filepath_placeholder,
        server_prepare_output_filepath_placeholder,
        server_prepare_intermediate_state_output_filepath_placeholder,
    ])

    # Restore the server state.
    server_state_type = daf.server_prepare.type_signature.parameter
    server_state_vars = variable_helpers.create_vars_for_tff_type(
        server_state_type, name='server'
    )
    server_state_tensor_specs = tf.nest.map_structure(
        variable_helpers.tensorspec_from_var, server_state_vars
    )
    server_state = checkpoint_utils.restore_tensors_from_savepoint(
        server_state_tensor_specs,
        server_prepare_server_state_input_filepath_placeholder,
    )

    # TODO(team): Add support for federated select slice generation.

    # Perform the server_prepare step.
    prepared_values, intermediate_state_values = (
        graph_helpers.import_tensorflow(
            'server_prepare',
            tff.framework.ConcreteComputation.from_building_block(
                tff.backends.mapreduce.consolidate_and_extract_local_processing(
                    daf.server_prepare.to_building_block(), grappler_config
                )
            ),
            server_state,
            split_outputs=True,
        )
    )

    # Create checkpoints storing the broadcast values and intermediate server
    # state. If there is no broadcast value, create a checkpoint containing a
    # placeholder tf.int32 constant since empty broadcasts are not supported.
    # If there is no intermediate server state, don't create an intermediate
    # server state checkpoint.
    save_tensor_names = variable_helpers.variable_names_from_type(
        daf.server_prepare.type_signature.result[0], name='client'
    )
    save_tensors = prepared_values
    if is_broadcast_empty:
      save_tensor_names = variable_helpers.variable_names_from_type(
          tff.StructType([tf.int32]), name='client'
      )
      save_tensors = [tf.constant(0, tf.int32)]
    prepared_values_save_op = tensor_utils.save(
        filename=server_prepare_output_filepath_placeholder,
        tensor_names=save_tensor_names,
        tensors=save_tensors,
        name='save_prepared_values_tensors',
    )
    server_prepare_target_nodes.append(prepared_values_save_op.name)

    intermediate_state_empty = (
        isinstance(daf.server_prepare.type_signature.result[1], tff.StructType)
        and not daf.server_prepare.type_signature.result[1]
    )
    if not intermediate_state_empty:
      intermediate_state_values_save_op = tensor_utils.save(
          filename=server_prepare_intermediate_state_output_filepath_placeholder,
          tensor_names=variable_helpers.variable_names_from_type(
              daf.server_prepare.type_signature.result[1], 'intermediate_state'
          ),
          tensors=intermediate_state_values,
          name='save_intermediate_state_values_tensors',
      )
      server_prepare_target_nodes.append(intermediate_state_values_save_op.name)

  # Build aggregations.
  # The client_to_server_aggregation computation is guaranteed to conform to
  # a specific structure. It is a lambda computation whose result block contains
  # locals that are exclusively aggregation-type intrinsics.
  aggregations_bb = daf.client_to_server_aggregation.to_building_block()
  aggregations_bb.check_lambda()
  aggregations_bb.result.check_block()  # pytype: disable=attribute-error

  # Get lists of the TensorSpecProtos for the inputs and outputs of all
  # intrinsic calls. These lists are formatted such that the ith entry
  # represents the TensorSpecProtos for the ith intrinsic in the aggregation
  # computation. Since intrinsics may have one or more args, the ith entry in
  # the input TensorSpecProto list is itself a list, where the jth entry
  # represents the TensorSpecProtos corresponding to the jth argument of the
  # ith intrinsic.
  grouped_input_tensor_specs = variable_helpers.get_grouped_input_tensor_specs_for_aggregations(
      aggregations_bb,
      artifact_constants.AGGREGATION_INTRINSIC_ARG_SELECTION_INDEX_TO_NAME_DICT,
  )
  grouped_output_tensor_specs = (
      variable_helpers.get_grouped_output_tensor_specs_for_aggregations(
          aggregations_bb
      )
  )
  assert len(grouped_input_tensor_specs) == len(grouped_output_tensor_specs)

  intrinsic_uris = [
      local_value.function.intrinsic_def().uri
      for _, local_value in aggregations_bb.result.locals  # pytype: disable=attribute-error
  ]
  assert len(intrinsic_uris) == len(grouped_output_tensor_specs)

  # Each intrinsic input arg can be a struct or even a nested struct, which
  # requires the intrinsic to be applied independently to each element (e.g. a
  # tff.federated_sum call applied to a struct will result in a federated_sum
  # aggregation message for each element of the struct). Note that elements of
  # structs can themselves be multi-dimensional tensors. When an intrinsic call
  # has multiple args with mismatching structure (e.g. a federated_weighted_mean
  # intrinsic applied to a 2D struct value arg and scalar weight arg), some args
  # will need to be "scaled up" via repetition to match the args with the
  # "largest" structure.
  aggregations = []
  for intrinsic_index, (input_tensor_specs, output_tensor_specs) in enumerate(
      zip(grouped_input_tensor_specs, grouped_output_tensor_specs)
  ):
    # Generate the aggregation messages for this intrinsic call.
    max_input_struct_length = max([len(x) for x in input_tensor_specs])
    max_struct_length = max(max_input_struct_length, len(output_tensor_specs))
    for i in range(max_struct_length):
      intrinsic_args = []
      for j, _ in enumerate(input_tensor_specs):
        # Scale up any "smaller" structure args by reusing their last element.
        tensor_spec = input_tensor_specs[j][
            min(i, len(input_tensor_specs[j]) - 1)
        ]
        if tensor_spec.name.startswith('update'):
          intrinsic_args.append(
              plan_pb2.ServerAggregationConfig.IntrinsicArg(
                  input_tensor=tensor_spec.experimental_as_proto()
              )
          )
        else:
          intrinsic_args.append(
              plan_pb2.ServerAggregationConfig.IntrinsicArg(
                  state_tensor=tensor_spec.experimental_as_proto()
              )
          )
      aggregations.append(
          plan_pb2.ServerAggregationConfig(
              intrinsic_uri=intrinsic_uris[intrinsic_index],
              intrinsic_args=intrinsic_args,
              # Scale up the output structure by reusing the last element if
              # needed.
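              # (Continuing the federated_weighted_mean example above: a
              # three-element value struct with a scalar weight would yield
              # three ServerAggregationConfig messages here, each reusing the
              # single weight spec as an argument.)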
              output_tensors=[
                  output_tensor_specs[
                      min(i, len(output_tensor_specs) - 1)
                  ].experimental_as_proto()
              ],
          )
      )

  # Generate the TensorFlow graph needed to execute the server_result step,
  # including reading input checkpoints, writing output checkpoints, and
  # generating output tensors.
  server_result_input_tensors = []
  server_result_output_tensors = []
  server_result_target_nodes = []
  with tf.Graph().as_default() as server_result_graph:
    # Create the placeholders for the input and output filenames needed by
    # the server_result step.
    server_result_intermediate_state_input_filepath_placeholder = (
        tf.compat.v1.placeholder(
            name='server_intermediate_state_input_filepath',
            shape=(),
            dtype=tf.string,
        )
    )
    server_result_aggregate_result_input_filepath_placeholder = (
        tf.compat.v1.placeholder(
            name='aggregate_result_input_filepath', shape=(), dtype=tf.string
        )
    )
    server_result_server_state_output_filepath_placeholder = (
        tf.compat.v1.placeholder(
            name='server_state_output_filepath', shape=(), dtype=tf.string
        )
    )
    server_result_input_tensors.extend([
        server_result_intermediate_state_input_filepath_placeholder,
        server_result_aggregate_result_input_filepath_placeholder,
        server_result_server_state_output_filepath_placeholder,
    ])

    # Restore the intermediate server state.
    intermediate_state = []
    if not intermediate_state_empty:
      intermediate_state_type = daf.server_result.type_signature.parameter[0]
      intermediate_state_vars = variable_helpers.create_vars_for_tff_type(
          intermediate_state_type, 'intermediate_state'
      )
      intermediate_state_tensor_specs = tf.nest.map_structure(
          variable_helpers.tensorspec_from_var, intermediate_state_vars
      )
      intermediate_state = checkpoint_utils.restore_tensors_from_savepoint(
          intermediate_state_tensor_specs,
          server_result_intermediate_state_input_filepath_placeholder,
      )

    # Restore the aggregation results.
    aggregate_result_type = tff.StructType(
        [daf.server_result.type_signature.parameter[1]]
    )
    aggregate_result_vars = variable_helpers.create_vars_for_tff_type(
        aggregate_result_type, 'intermediate_update'
    )
    aggregate_result_tensor_specs = tf.nest.map_structure(
        variable_helpers.tensorspec_from_var, aggregate_result_vars
    )
    aggregate_result = checkpoint_utils.restore_tensors_from_savepoint(
        aggregate_result_tensor_specs,
        server_result_aggregate_result_input_filepath_placeholder,
    )

    # Perform the server_result step.
    server_state_values, server_output_values = graph_helpers.import_tensorflow(
        'server_result',
        tff.framework.ConcreteComputation.from_building_block(
            tff.backends.mapreduce.consolidate_and_extract_local_processing(
                daf.server_result.to_building_block(), grappler_config
            )
        ),
        (intermediate_state, aggregate_result),
        split_outputs=True,
    )

    # Create checkpoints storing the updated server state.
    server_state_save_op = tensor_utils.save(
        filename=server_result_server_state_output_filepath_placeholder,
        tensor_names=variable_helpers.variable_names_from_type(
            daf.server_result.type_signature.result[0], 'server'
        ),
        tensors=server_state_values,
        name='save_server_state_tensors',
    )
    server_result_target_nodes.append(server_state_save_op.name)

    # Generate the output TensorSpecProtos for the server metrics if some exist.
    server_output_empty = (
        isinstance(daf.server_result.type_signature.result[1], tff.StructType)
        and not daf.server_result.type_signature.result[1]
    )
    if not server_output_empty:
      metric_names = variable_helpers.variable_names_from_type(
          daf.server_result.type_signature.result[1], 'server'
      )
      metric_tensors = [
          tf.identity(tensor, name)
          for tensor, name in zip(server_output_values, metric_names)
      ]
      for metric in metric_tensors:
        server_result_output_tensors.append(
            proto_helpers.make_tensor_spec_from_tensor(
                metric
            ).experimental_as_proto()
        )

  # Create the TensorflowSpec messages for the pre-broadcast (server_prepare)
  # and post-aggregation (server_result) steps.
  tensorflow_spec_prepare = plan_pb2.TensorflowSpec(
      input_tensor_specs=[
          proto_helpers.make_tensor_spec_from_tensor(t).experimental_as_proto()
          for t in server_prepare_input_tensors
      ],
      target_node_names=server_prepare_target_nodes,
  )
  tensorflow_spec_result = plan_pb2.TensorflowSpec(
      input_tensor_specs=[
          proto_helpers.make_tensor_spec_from_tensor(t).experimental_as_proto()
          for t in server_result_input_tensors
      ],
      output_tensor_specs=server_result_output_tensors,
      target_node_names=server_result_target_nodes,
  )

  # Create the IORouter messages for the pre-broadcast (server_prepare) and
  # post-aggregation (server_result) steps.
  server_prepare_io_router = plan_pb2.ServerPrepareIORouter(
      prepare_server_state_input_filepath_tensor_name=server_prepare_server_state_input_filepath_placeholder.name,
      prepare_output_filepath_tensor_name=server_prepare_output_filepath_placeholder.name,
      prepare_intermediate_state_output_filepath_tensor_name=server_prepare_intermediate_state_output_filepath_placeholder.name,
  )
  server_result_io_router = plan_pb2.ServerResultIORouter(
      result_intermediate_state_input_filepath_tensor_name=server_result_intermediate_state_input_filepath_placeholder.name,
      result_aggregate_result_input_filepath_tensor_name=server_result_aggregate_result_input_filepath_placeholder.name,
      result_server_state_output_filepath_tensor_name=server_result_server_state_output_filepath_placeholder.name,
  )

  server_phase_v2 = plan_pb2.ServerPhaseV2(
      tensorflow_spec_prepare=tensorflow_spec_prepare,
      prepare_router=server_prepare_io_router,
      aggregations=aggregations,
      tensorflow_spec_result=tensorflow_spec_result,
      result_router=server_result_io_router,
  )

  return (
      server_prepare_graph.as_graph_def(),
      server_result_graph.as_graph_def(),
      server_phase_v2,
  )


def _build_server_graph(
    mrf: tff.backends.mapreduce.MapReduceForm,
    broadcast_tff_type: variable_helpers.AllowedTffTypes,
    is_broadcast_empty: bool,
    flattened_bitwidths: list[int],
    flattened_max_inputs: list[int],
    flattened_moduli: list[int],
    write_metrics_to_checkpoint: bool = True,
    additional_checkpoint_metadata_var_fn: Optional[
        Callable[[tff.StructType, tff.StructType, bool], list[tf.Variable]]
    ] = None,
    experimental_client_update_format: checkpoint_type.CheckpointFormatType = checkpoint_type.CheckpointFormatType.TF1_SAVE_SLICES,
) -> tuple[
    tf.compat.v1.GraphDef,
    plan_pb2.CheckpointOp,
    plan_pb2.ServerPhase,
    list[tf.TensorSpec],
]:
  """Builds the `tf.Graph` that will run on the server.

  Args:
    mrf: A `MapReduceForm` object containing the different computations to
      combine into a single server graph.
    broadcast_tff_type: A `tff.Type` object that specifies the tensors in the
      model that are broadcasted and aggregated.
    is_broadcast_empty: A boolean indicating whether the broadcasted value from
      the server was initially empty.
    flattened_bitwidths: The `bitwidth`s for the tensors that will be aggregated
      using `bitwidth`-based secure summation.
    flattened_max_inputs: The `max_input`s for the tensors that will be
      aggregated using `max_input`-based secure summation.
    flattened_moduli: The `modulus`s for the tensors that will be aggregated
      using modular secure summation.
    write_metrics_to_checkpoint: If False, revert to legacy behavior where
      metrics values were handled by post-processing separate from the outputted
      checkpoint. Regardless, they will additionally continue to be written to
      recordio and accumulator checkpoints as defined by the Plan proto.
    additional_checkpoint_metadata_var_fn: An optional method that takes in a
      server state type, a server metrics type, and a boolean determining
      whether to revert to legacy metrics behavior to produce additional
      metadata variables.
    experimental_client_update_format: Determines how the client update will be
      interpreted. The value has to match the `experimental_checkpoint_write`
      argument of the `_build_client_graph_with_tensorflow_spec` call.

  Returns:
    A `tuple` containing the following (in order):
      - A server `tf.GraphDef`,
      - A server checkpoint,
      - A server phase proto message, and
      - A list of `tf.TensorSpec`s for the broadcasted values.
  """
  (
      simpleagg_update_type,
      secure_sum_bitwidth_update_type,
      secure_sum_update_type,
      secure_modular_sum_update_type,
  ) = mrf.work.type_signature.result
  with tf.Graph().as_default() as server_graph:
    # Creates all server-side variables and savepoints for both the coordinator
    # and the intermediate aggregators.
    # server_state_type will be a SERVER-placed federated type.
    server_state_type, server_metrics_type = mrf.type_signature.result
    assert server_state_type.is_federated(), server_state_type
    assert server_state_type.placement == tff.SERVER, server_state_type
    # server_metrics_type can be a tff.FederatedType or a structure containing
    # tff.FederatedTypes.
    if isinstance(server_metrics_type, tff.FederatedType):
      # We need to check for server metrics without the placement so
      # tff.structure.flatten works correctly.
      has_server_metrics = bool(
          tff.structure.flatten(server_metrics_type.member)
      )
    else:
      has_server_metrics = bool(tff.structure.flatten(server_metrics_type))
    if isinstance(server_metrics_type, tff.TensorType) or (
        isinstance(server_metrics_type, tff.FederatedType)
        and isinstance(server_metrics_type.member, tff.TensorType)
    ):
      # Single tensor; must be wrapped inside of a NamedTuple for proper
      # variable initialization.
      server_metrics_type = tff.StructType([server_metrics_type])
    (
        server_state_vars,
        server_metrics_vars,
        metadata_vars,
        server_savepoint,
    ) = checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
        server_state_type=server_state_type,
        server_metrics_type=server_metrics_type,
        write_metrics_to_checkpoint=write_metrics_to_checkpoint,
        additional_checkpoint_metadata_var_fn=(
            additional_checkpoint_metadata_var_fn
        ),
    )

    # TODO(team): Switch to `tf.save()` in lieu of savers to avoid the
    # need to create client variables on the server.
    client_vars_on_server, write_client = (
        checkpoint_utils.create_state_vars_and_savepoint(
            broadcast_tff_type, 'client'
        )
    )

    secure_sum_update_types = [
        secure_sum_bitwidth_update_type,
        secure_sum_update_type,
        secure_modular_sum_update_type,
    ]
    combined_intermediate_update_type = tff.StructType(
        [mrf.zero.type_signature.result] + secure_sum_update_types
    )

    combined_intermediate_update_vars, write_intermediate_update = (
        checkpoint_utils.create_state_vars_and_savepoint(
            combined_intermediate_update_type, 'intermediate_update'
        )
    )
    num_simpleagg_vars = len(combined_intermediate_update_vars) - len(
        tff.structure.flatten(tff.to_type(secure_sum_update_types))
    )
    intermediate_update_vars = combined_intermediate_update_vars[
        :num_simpleagg_vars
    ]
    secagg_intermediate_update_vars = combined_intermediate_update_vars[
        num_simpleagg_vars:
    ]

    read_secagg_update = _read_secagg_update_from_sidechannel_into_vars(
        secagg_intermediate_update_vars=secagg_intermediate_update_vars,
        secure_sum_bitwidth_update_type=secure_sum_bitwidth_update_type,
        bitwidths=flattened_bitwidths,
        secure_sum_update_type=secure_sum_update_type,
        max_inputs=flattened_max_inputs,
        secure_modular_sum_update_type=secure_modular_sum_update_type,
        moduli=flattened_moduli,
    )

    combined_aggregated_update_vars, write_accumulators = (
        checkpoint_utils.create_state_vars_and_savepoint(
            combined_intermediate_update_type, 'aggregated_update'
        )
    )
    aggregated_update_vars = combined_aggregated_update_vars[
        :num_simpleagg_vars
    ]
    secagg_aggregated_update_vars = combined_aggregated_update_vars[
        num_simpleagg_vars:
    ]

    # Throws in the initializer for all state variables, to be executed prior
    # to restoring the savepoint. Run this variable initializer prior to
    # restoring from the savepoint to allow the vars to be overwritten by the
    # savepoint in this case, and so they do not get re-executed after being
    # overwritten. Also include the metrics vars here in case the execution
    # environment wants to read those in.
    server_vars_initializer = tf.compat.v1.variables_initializer(
        server_state_vars + metadata_vars + server_metrics_vars,
        'initialize_server_state_and_non_state_vars',
    )
    server_savepoint.before_restore_op = server_vars_initializer.name

    # In graph mode, TensorFlow does not allow creating a
    # `tf.compat.v1.train.Saver` when there are no variables. As a result,
    # calling `create_state_vars_and_savepoint` below will fail when there are
    # no SimpleAgg variables (e.g., all results are aggregated via SecAgg). In
    # this case, there are no client checkpoints, and hence, no need to populate
    # the `read_update` field.
    if num_simpleagg_vars > 0:
      # Run the initializer for update vars prior to restoring the client
      # update.
      update_vars, read_update = (
          checkpoint_utils.create_state_vars_and_savepoint(
              simpleagg_update_type, artifact_constants.UPDATE
          )
      )
      update_vars_initializer = tf.compat.v1.variables_initializer(
          update_vars, 'initialize_update_vars'
      )
      if (
          experimental_client_update_format
          == checkpoint_type.CheckpointFormatType.APPEND_SLICES_MERGE_READ
      ):
        graph = tf.compat.v1.get_default_graph()
        checkpoint_pl = graph.get_tensor_by_name(
            read_update.saver_def.filename_tensor_name
        )
        merge_checkpoint_slices = append_slices.merge_appended_slices(
            checkpoint_pl, 'merge_checkpoint_slices'
        )
        init_merge = tf.group(update_vars_initializer, merge_checkpoint_slices)
        read_update.before_restore_op = init_merge.name
      else:
        read_update.before_restore_op = update_vars_initializer.name
    else:
      # Create an empty list for `update_vars` when there are no SimpleAgg
      # variables, to be compatible with the `accumulated_values` defined below.
      update_vars = []

    # Copy the intermediate aggregator's update saver for use on coordinator.
    read_intermediate_update = plan_pb2.CheckpointOp()
    read_intermediate_update.CopyFrom(write_intermediate_update)

    # Condition all the remaining logic on the variable initializers, since
    # intermediate aggregators are supposed to be stateless (no savepoint, and
    # therefore no `before_restore_op`, either).
    with tf.control_dependencies(
        [
            tf.compat.v1.variables_initializer(
                (intermediate_update_vars + aggregated_update_vars),
                'initialize_accumulator_vars',
            )
        ]
    ):
      # Embeds the `zero` logic and hooks it up to `after_restore_op` of
      # server's checkpointed state (shared between the coordinator and the
      # intermediate aggregators). The zeros get assigned to
      # `intermediate_update_vars` and to the `aggregated_update_vars` at the
      # very beginning, right after restoring from `server_savepoint`.
      zero_values = graph_helpers.import_tensorflow('zero', mrf.zero)
      assign_zero_ops = tf.nest.map_structure(
          lambda variable, value: variable.assign(value),
          intermediate_update_vars,
          zero_values,
      ) + tf.nest.map_structure(
          lambda variable, value: variable.assign(value),
          aggregated_update_vars,
          zero_values,
      )

      # Embeds the `prepare` logic, and hooks it up to `before_save_op` of
      # client state (to be checkpointed and sent to the clients at the
      # beginning of the round by the central coordinator).
      with tf.control_dependencies(
          [
              tf.compat.v1.variables_initializer(
                  client_vars_on_server, 'initialize_client_vars_on_server'
              )
          ]
      ):
        # Configure the session token for `write_client` so that the `prepare`
        # operation may be fed the callback ID for the `SaveSlices` op
        # (necessary for plans containing `federated_select`).
        write_client_session_token = tf.compat.v1.placeholder_with_default(
            input='', shape=(), name='write_client_session_token'
        )
        prepared_values = graph_helpers.import_tensorflow(
            'prepare',
            mrf.prepare,
            server_state_vars,
            session_token_tensor=write_client_session_token,
        )
        if is_broadcast_empty:
          # If the broadcast was empty, don't assign the sample incoming
          # tf.int32 to anything.
          client_state_assign_ops = [tf.no_op()]
        else:
          client_state_assign_ops = tf.nest.map_structure(
              lambda variable, tensor: variable.assign(tensor),
              client_vars_on_server,
              prepared_values,
          )
        write_client.before_save_op = tf.group(*client_state_assign_ops).name
        write_client.session_token_tensor_name = write_client_session_token.name

      # Embeds the `accumulate` logic, and hooks up the assignment of a client
      # update to the intermediate update to `aggregate_into_accumulators_op`.
      accumulated_values = graph_helpers.import_tensorflow(
          'accumulate', mrf.accumulate, (intermediate_update_vars, update_vars)
      )
      intermediate_update_assign_ops = tf.nest.map_structure(
          lambda variable, tensor: variable.assign(tensor),
          intermediate_update_vars,
          accumulated_values,
      )
      aggregate_into_accumulators_op = tf.group(
          *intermediate_update_assign_ops
      ).name

      secagg_aggregated_update_init = tf.compat.v1.variables_initializer(
          secagg_aggregated_update_vars
      )

      # Reset the accumulators in `phase_init_op`, after variable initializers
      # and after restoring from the savepoint.
      phase_init_op = tf.group(
          *(assign_zero_ops + [secagg_aggregated_update_init])
      ).name

      # Embeds the `merge` logic, and hooks up the assignment of an intermediate
      # update to the top-level aggregate update at the coordinator to
      # `intermediate_aggregate_into_accumulators_op`.
      merged_values = graph_helpers.import_tensorflow(
          'merge', mrf.merge, (aggregated_update_vars, intermediate_update_vars)
      )
      aggregated_update_assign_ops = tf.nest.map_structure(
          lambda variable, tensor: variable.assign(tensor),
          aggregated_update_vars,
          merged_values,
      )

      secagg_aggregated_update_ops = _merge_secagg_vars(
          secure_sum_bitwidth_update_type,
          secure_sum_update_type,
          secure_modular_sum_update_type,
          flattened_moduli,
          secagg_aggregated_update_vars,
          secagg_intermediate_update_vars,
      )

      intermediate_aggregate_into_accumulators_op = tf.group(
          *(aggregated_update_assign_ops + secagg_aggregated_update_ops)
      ).name

      # Embeds the `report` and `update` logic, and hooks up the assignments of
      # the results of the final update to the server state and metric vars, to
      # be triggered by `apply_aggregrated_updates_op`.
      simpleagg_reported_values = graph_helpers.import_tensorflow(
          'report', mrf.report, aggregated_update_vars
      )

      # NOTE: In MapReduceForm, the `update` method takes in the simpleagg vars
      # and SecAgg vars as a tuple of two separate lists. However, here, as
      # above, we concatenate the simpleagg values and the secure values into a
      # single list. This mismatch is not a problem because the variables are all
      # flattened either way when traveling in and out of the tensorflow graph.
      combined_update_vars = (
          simpleagg_reported_values + secagg_aggregated_update_vars
      )
      new_server_state_values, server_metrics_values = (
          graph_helpers.import_tensorflow(
              artifact_constants.UPDATE,
              mrf.update,
              (server_state_vars, combined_update_vars),
              split_outputs=True,
          )
      )

      assign_server_state_ops = tf.nest.map_structure(
          lambda variable, tensor: variable.assign(tensor),
          server_state_vars,
          new_server_state_values,
      )
      assign_non_state_ops = tf.nest.map_structure(
          lambda variable, value: variable.assign(value),
          server_metrics_vars,
          server_metrics_values,
      )
      all_assign_ops = assign_server_state_ops + assign_non_state_ops
      apply_aggregrated_updates_op = tf.group(*all_assign_ops).name

  # Constructs the metadata for server metrics to be included in the plan.
  server_metrics = [
      proto_helpers.make_metric(v, artifact_constants.SERVER_STATE_VAR_PREFIX)
      for v in server_metrics_vars
  ]

  server_phase_kwargs = collections.OrderedDict(
      phase_init_op=phase_init_op,
      write_client_init=write_client,
      read_aggregated_update=read_secagg_update,
      write_intermediate_update=write_intermediate_update,
      read_intermediate_update=read_intermediate_update,
      intermediate_aggregate_into_accumulators_op=(
          intermediate_aggregate_into_accumulators_op
      ),
      write_accumulators=write_accumulators,
      apply_aggregrated_updates_op=apply_aggregrated_updates_op,
      metrics=server_metrics,
  )

  if num_simpleagg_vars > 0:
    # The `read_update` loads SimpleAgg updates from client checkpoints. The
    # `aggregate_into_accumulators_op` aggregates SimpleAgg data after loading
    # the client updates. No need to populate the two fields if there are no
    # SimpleAgg variables (e.g., if all results are aggregated via SecAgg).
    server_phase_kwargs['read_update'] = read_update
    server_phase_kwargs['aggregate_into_accumulators_op'] = (
        aggregate_into_accumulators_op
    )

  server_phase = plan_pb2.ServerPhase(**server_phase_kwargs)

  broadcasted_tensor_specs = tf.nest.map_structure(
      variable_helpers.tensorspec_from_var, client_vars_on_server
  )
  server_graph_def = server_graph.as_graph_def()

  if write_metrics_to_checkpoint:
    server_graph_def = _redirect_save_saver_to_restore_saver_placeholder(
        server_graph_def
    )

  return (
      server_graph_def,
      server_savepoint,
      server_phase,
      broadcasted_tensor_specs,
  )


def _redirect_save_saver_to_restore_saver_placeholder(
    graph_def: tf.compat.v1.GraphDef,
) -> tf.compat.v1.GraphDef:
  """Updates save Saver's savepoint to point to restore Saver's placeholder.

  NOTE: mutates the GraphDef passed in and returns the mutated GraphDef.

  When we create the server_savepoint Saver while also outputting all of
  the metrics to the output checkpoint, we set different nodes for saving
  and restoring so that we can save state + metrics and restore just state.
  However, the only way to do so was to make two Savers and splice them
  together. This meant that the save and restore operations depend on two
  different placeholders for the checkpoint filename. To avoid requiring
  the server to pass the same checkpoint name twice, once to each
  placeholder, we make a few changes to the server GraphDef so that the
  saving op connects back to the placeholder for the restore operation.
  Once this is done, the original save placeholder node will still exist in
  the graph, but it won't be used by any part of the graph that connects to
  an operation we care about.

  Args:
    graph_def: A `tf.compat.v1.GraphDef` to mutate.

  Returns:
    The mutated `tf.compat.v1.GraphDef` that was passed in as graph_def.
  """
  old_const_node = f'{checkpoint_utils.SAVE_SERVER_SAVEPOINT_NAME}/Const'
  new_const_node = (
      f'{artifact_constants.SERVER_STATE_VAR_PREFIX}_savepoint/Const'
  )
  nodes_to_change = [
      f'{checkpoint_utils.SAVE_SERVER_SAVEPOINT_NAME}/save',
      f'{checkpoint_utils.SAVE_SERVER_SAVEPOINT_NAME}/control_dependency',
      f'{checkpoint_utils.SAVE_SERVER_SAVEPOINT_NAME}/RestoreV2',
  ]
  num_changed_nodes = 0
  for node in graph_def.node:
    if node.name in nodes_to_change:
      input_index = 0
      for input_index, input_node in enumerate(node.input):
        if input_node == old_const_node:
          node.input[input_index] = new_const_node
          break
      assert input_index != len(
          node.input
      ), 'Missed input arg in saver GraphDef rewriting.'
      num_changed_nodes = num_changed_nodes + 1
      if num_changed_nodes == len(nodes_to_change):
        # Once we've changed all of the callsites, we stop.
        return graph_def
  return graph_def


def _build_client_graph_with_tensorflow_spec(
    client_work_comp: tff.Computation,
    dataspec,
    broadcasted_tensor_specs: Iterable[tf.TensorSpec],
    is_broadcast_empty: bool,
    *,
    experimental_checkpoint_write: checkpoint_type.CheckpointFormatType = checkpoint_type.CheckpointFormatType.TF1_SAVE_SLICES,
) -> tuple[tf.compat.v1.GraphDef, plan_pb2.ClientPhase]:
  """Builds the client graph and ClientPhase with TensorflowSpec populated.

  This function builds a client phase with the TensorflowSpec proto populated.

  Args:
    client_work_comp: A `tff.Computation` that represents the TensorFlow logic
      run on-device.
    dataspec: Either an instance of `data_spec.DataSpec` or a nested structure
      of these that matches the structure of the first element of the input to
      `client_work_comp`.
    broadcasted_tensor_specs: A list of `tf.TensorSpec` containing the name and
      dtype of the variables arriving via the broadcast checkpoint.
    is_broadcast_empty: A boolean indicating whether the MapReduce form
      initially called for an empty broadcast. In this case the
      broadcasted_tensor_specs will contain a single tf.int32, but it will be
      ignored.
    experimental_checkpoint_write: Determines the format of the final client
      update checkpoint. The value affects required operations and might have
      performance implications.

  Returns:
    A `tuple` of the client TensorFlow GraphDef and the client phase protocol
    message.

  Raises:
    SecureAggregationTensorShapeError: If SecAgg tensors do not have all
      dimensions of their shape fully defined.
    ValueError: If any of the arguments are found to be in an unexpected form.
  """
  if (
      not isinstance(client_work_comp.type_signature.parameter, tff.StructType)
      or len(client_work_comp.type_signature.parameter) < 1
  ):
    raise ValueError(
        'client_work_comp.type_signature.parameter should be a '
        '`tff.StructType` with length >= 1, but found: {p}.'.format(
            p=client_work_comp.type_signature.parameter
        )
    )

  if (
      not isinstance(client_work_comp.type_signature.result, tff.StructType)
      or len(client_work_comp.type_signature.result) != 4
  ):
    raise ValueError(
        'client_work_comp.type_signature.result should be a '
        '`tff.StructType` with length == 4, but found: {r}.'.format(
            r=client_work_comp.type_signature.result
        )
    )

  (
      simpleagg_update_type,
      secure_sum_bitwidth_update_type,
      secure_sum_update_type,
      secure_modular_sum_update_type,
  ) = client_work_comp.type_signature.result

  # A list of tensors that will be passed into TensorFlow, corresponding to
  # `plan_pb2.ClientPhase.tensorflow_spec.input_tensor_specs`. Note that the
  # dataset token is excluded from this list. In general, this list should
  # include the filepath placeholder tensors for the input checkpoint file and
  # output checkpoint file.
  input_tensors = []

  # A list of tensor specs that should be fetched from TensorFlow, corresponding
  # to `plan_pb2.ClientPhase.tensorflow_spec.output_tensor_specs`. In general,
  # this list should include the tensors that are not in the output checkpoint
  # file, such as secure aggregation tensors.
  output_tensor_specs = []

  # A list of node names in the client graph that should be executed but no
  # output returned, corresponding to
  # `plan_pb2.ClientPhase.tensorflow_spec.target_node_names`. In general, this
  # list should include the op that creates the output checkpoint file.
  target_nodes = []
  with tf.Graph().as_default() as client_graph:
    input_filepath_placeholder = None
    if not is_broadcast_empty:
      input_filepath_placeholder = tf.compat.v1.placeholder(
          name='input_filepath', shape=(), dtype=tf.string
      )
      weights_from_server = checkpoint_utils.restore_tensors_from_savepoint(
          broadcasted_tensor_specs, input_filepath_placeholder
      )
      input_tensors.append(input_filepath_placeholder)
    else:
      weights_from_server = []

    # Add the custom Dataset ops to the graph.
    token_placeholder, data_values, example_selector_placeholders = (
        graph_helpers.embed_data_logic(
            client_work_comp.type_signature.parameter[0], dataspec
        )
    )

    # Embed the graph coming from TFF into the client work graph.
    combined_update_tensors = graph_helpers.import_tensorflow(
        'work',
        client_work_comp,
        (data_values, weights_from_server),
        split_outputs=False,
        session_token_tensor=token_placeholder,
    )  # pytype: disable=wrong-arg-types

    num_simpleagg_tensors = len(tff.structure.flatten(simpleagg_update_type))
    simpleagg_tensors = combined_update_tensors[:num_simpleagg_tensors]
    secagg_tensors = combined_update_tensors[num_simpleagg_tensors:]

    # For tensors aggregated by secagg, we make sure the tensor names are
    # aligned in both client and server graph by getting the names from the
    # same method.
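    # The same helper is used by `_create_secagg_sidechannels` when building
    # the server graph, so the `tf.identity` names assigned below line up with
    # the side-channel tensor names that the server's `read_secagg_update`
    # CheckpointOp expects.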
    secagg_tensor_names = []
    secagg_tensor_types = []
    for uri, update_type in [
        (SECURE_SUM_BITWIDTH_URI, secure_sum_bitwidth_update_type),
        (SECURE_SUM_URI, secure_sum_update_type),
        (SECURE_MODULAR_SUM_URI, secure_modular_sum_update_type),
    ]:
      secagg_tensor_names += variable_helpers.get_shared_secagg_tensor_names(
          uri, update_type
      )
      secagg_tensor_types += tff.structure.flatten(update_type)

    secagg_tensors = [
        tf.identity(tensor, name=tensor_utils.bare_name(name))
        for tensor, name in zip(secagg_tensors, secagg_tensor_names)
    ]
    for t, type_spec in zip(secagg_tensors, secagg_tensor_types):
      secagg_tensor_spec = proto_helpers.make_tensor_spec_from_tensor(
          t, shape_hint=type_spec.shape
      )
      output_tensor_specs.append(secagg_tensor_spec.experimental_as_proto())

    # Verify that SecAgg tensors have all dimensions fully defined.
    for tensor_spec in output_tensor_specs:
      if not tf.TensorShape(tensor_spec.shape).is_fully_defined():
        raise SecureAggregationTensorShapeError(
            '`TensorflowSpec.output_tensor_specs` has unknown dimension.'
        )

    output_filepath_placeholder = None
    if simpleagg_tensors:
      output_filepath_placeholder = tf.compat.v1.placeholder(
          dtype=tf.string, shape=(), name='output_filepath'
      )
      simpleagg_variable_names = variable_helpers.variable_names_from_type(
          simpleagg_update_type, name=artifact_constants.UPDATE
      )
      if experimental_checkpoint_write in [
          checkpoint_type.CheckpointFormatType.APPEND_SLICES_MERGE_WRITE,
          checkpoint_type.CheckpointFormatType.APPEND_SLICES_MERGE_READ,
      ]:
        delete_op = delete_file.delete_file(output_filepath_placeholder)
        with tf.control_dependencies([delete_op]):
          append_ops = []
          for tensor_name, tensor in zip(
              simpleagg_variable_names, simpleagg_tensors
          ):
            append_ops.append(
                tensor_utils.save(
                    filename=output_filepath_placeholder,
                    tensor_names=[tensor_name],
                    tensors=[tensor],
                    save_op=append_slices.append_slices,
                )
            )
        if (
            experimental_checkpoint_write
            == checkpoint_type.CheckpointFormatType.APPEND_SLICES_MERGE_WRITE
        ):
          with tf.control_dependencies(append_ops):
            save_op = append_slices.merge_appended_slices(
                filename=output_filepath_placeholder
            )
        else:
          # APPEND_SLICES_MERGE_READ
          save_op = tf.group(*append_ops)

      elif (
          experimental_checkpoint_write
          == checkpoint_type.CheckpointFormatType.TF1_SAVE_SLICES
      ):
        save_op = tensor_utils.save(
            filename=output_filepath_placeholder,
            tensor_names=simpleagg_variable_names,
            tensors=simpleagg_tensors,
            name='save_client_update_tensors',
        )
      else:
        raise NotImplementedError(
            f'Unsupported CheckpointFormatType {experimental_checkpoint_write}.'
1373 ) 1374 input_tensors.append(output_filepath_placeholder) 1375 target_nodes.append(save_op.name) 1376 1377 tensorflow_spec = plan_pb2.TensorflowSpec() 1378 if token_placeholder is not None: 1379 tensorflow_spec.dataset_token_tensor_name = token_placeholder.name 1380 if input_tensors: 1381 tensorflow_spec.input_tensor_specs.extend( 1382 tf.TensorSpec.from_tensor(t, name=t.name).experimental_as_proto() 1383 for t in input_tensors 1384 ) 1385 if output_tensor_specs: 1386 tensorflow_spec.output_tensor_specs.extend(output_tensor_specs) 1387 if target_nodes: 1388 tensorflow_spec.target_node_names.extend(target_nodes) 1389 if example_selector_placeholders: 1390 for placeholder in example_selector_placeholders: 1391 # Generating the default TensorProto will create a TensorProto with an 1392 # DT_INVALID DType. This identifies that there is a placeholder that is 1393 # needed. In order to have the Plan proto be completely runnable, the 1394 # value will need to be filled in with a real TensorProto that matches 1395 # the shape/type of the expected input. 1396 tensorflow_spec.constant_inputs[placeholder.name].dtype = 0 1397 1398 io_router = plan_pb2.FederatedComputeIORouter() 1399 if input_filepath_placeholder is not None: 1400 io_router.input_filepath_tensor_name = input_filepath_placeholder.name 1401 if output_filepath_placeholder is not None: 1402 io_router.output_filepath_tensor_name = output_filepath_placeholder.name 1403 for secagg_tensor in secagg_tensors: 1404 io_router.aggregations[secagg_tensor.name].CopyFrom( 1405 plan_pb2.AggregationConfig( 1406 secure_aggregation=plan_pb2.SecureAggregationConfig() 1407 ) 1408 ) 1409 1410 return client_graph.as_graph_def(), plan_pb2.ClientPhase( 1411 tensorflow_spec=tensorflow_spec, federated_compute=io_router 1412 ) 1413 1414 1415def _build_client_phase_with_example_query_spec( 1416 client_work_comp: tff.Computation, 1417 example_query_spec: plan_pb2.ExampleQuerySpec, 1418) -> plan_pb2.ClientPhase: 1419 """Builds the ClientPhase with `ExampleQuerySpec` populated. 1420 1421 Args: 1422 client_work_comp: A `tff.Computation` that represents the TensorFlow logic 1423 run on-device. 1424 example_query_spec: Field containing output vector information for client 1425 example query. The output vector names listed in the spec are expected to 1426 be consistent with the output names we would produce in the 1427 `MapReduceForm` client work computation, if we were to build a TF-based 1428 plan from that `MapReduceForm`. 1429 1430 Returns: 1431 A client phase part of the federated protocol. 1432 """ 1433 expected_vector_names = set( 1434 variable_helpers.variable_names_from_type( 1435 client_work_comp.type_signature.result[0], artifact_constants.UPDATE 1436 ) 1437 ) 1438 used_names = set() 1439 io_router = plan_pb2.FederatedExampleQueryIORouter() 1440 for example_query in example_query_spec.example_queries: 1441 vector_names = set(example_query.output_vector_specs.keys()) 1442 if not all([name in expected_vector_names for name in vector_names]): 1443 raise ValueError( 1444 'Found unexpected vector names in supplied `example_query_spec`. ' 1445 f'Expected names: {expected_vector_names}. ' 1446 f'Found unexpected names: {vector_names-expected_vector_names}.' 1447 ) 1448 1449 if any([name in used_names for name in vector_names]): 1450 raise ValueError( 1451 'Duplicate vector names found in supplied `example_query_spec`. 
          f'Duplicates: {vector_names.intersection(used_names)}'
      )

    used_names.update(vector_names)

    for vector_name in vector_names:
      io_router.aggregations[vector_name].CopyFrom(
          plan_pb2.AggregationConfig(
              tf_v1_checkpoint_aggregation=plan_pb2.TFV1CheckpointAggregation()
          )
      )

  if used_names != expected_vector_names:
    raise ValueError(
        'Not all expected vector names were in supplied `example_query_spec`.'
        f' Expected names: {expected_vector_names}. Names not present in'
        f' `example_query_spec`: {expected_vector_names - used_names}'
    )
  return plan_pb2.ClientPhase(
      example_query_spec=example_query_spec, federated_example_query=io_router
  )


def build_plan(
    mrf: tff.backends.mapreduce.MapReduceForm,
    daf: Optional[tff.backends.mapreduce.DistributeAggregateForm] = None,
    dataspec: Optional[data_spec.NestedDataSpec] = None,
    example_query_spec: Optional[plan_pb2.ExampleQuerySpec] = None,
    grappler_config: Optional[tf.compat.v1.ConfigProto] = None,
    additional_checkpoint_metadata_var_fn: Optional[
        Callable[[tff.StructType, tff.StructType, bool], list[tf.Variable]]
    ] = None,
    experimental_client_checkpoint_write: checkpoint_type.CheckpointFormatType = checkpoint_type.CheckpointFormatType.TF1_SAVE_SLICES,
    generate_server_phase_v2: bool = False,
    write_metrics_to_checkpoint: bool = True,
) -> plan_pb2.Plan:
  """Constructs an instance of `plan_pb2.Plan` given a `MapReduceForm` instance.

  Plans generated by this method are executable, but a number of features have
  yet to be implemented.

  These include:

  - Setting metrics' `stat_name` field based on externally-supplied metadata,
    such as that from the model stampers. Currently, these names are based on
    the names of TensorFlow variables, which in turn are based on the TFF
    type signatures.

  - Populating the client `example_selector` field. Currently not set.

  - Populating the client-side `savepoint`. Currently not set.

  - Populating the plan's `tensorflow_config_proto`. Currently not set.

  - Setting a field in the plan that represents a token to drive the custom op
    that implements the client-side dataset. There is no such field in the
    plan at the time of this writing.

  - Populating plan fields related to secure aggregation and side channels,
    such as the `read_aggregated_update` checkpoint op.

  An illustrative usage sketch appears at the end of this module.

  Args:
    mrf: An instance of `tff.backends.mapreduce.MapReduceForm`.
    daf: An instance of `tff.backends.mapreduce.DistributeAggregateForm`.
    dataspec: If provided, either an instance of `data_spec.DataSpec` or a
      nested structure of these that matches the structure of the first
      element of the input to the client-side processing computation
      `mrf.work`. If not provided and `example_query_spec` is also not
      provided, then placeholders are added to the client graph via
      `embed_data_logic()` and the example selectors will need to be passed to
      the client via the `constant_inputs` part of the `TensorflowSpec`. The
      `constant_inputs` field needs to be populated outside of `build_plan()`.
      Can only provide one of `dataspec` or `example_query_spec`.
    example_query_spec: An instance of `plan_pb2.ExampleQuerySpec`. If
      provided, it is assumed a lightweight client plan should be constructed.
      No client graph will be included in the produced plan object.
      Instead, the generated plan will have an `ExampleQuerySpec` and a
      `FederatedExampleQueryIORouter`. Can only supply one of `dataspec` or
      `example_query_spec`.
    grappler_config: The config specifying Grappler optimizations for
      TFF-generated graphs. Should be provided if `daf` is provided.
    additional_checkpoint_metadata_var_fn: An optional method that takes in a
      server state type, a server metrics type, and a boolean determining
      whether to revert to legacy metrics behavior, and produces additional
      metadata variables.
    experimental_client_checkpoint_write: Determines the style of writing of
      the client checkpoint (client->server communication). The value affects
      the operation used and might have an impact on overall task performance.
    generate_server_phase_v2: Iff `True`, will produce a ServerPhaseV2 message
      in the plan in addition to a ServerPhase message.
    write_metrics_to_checkpoint: If `False`, reverts to the legacy behavior
      where metric values were handled by post-processing separate from the
      output checkpoint. Regardless, they will additionally continue to be
      written to recordio and accumulator checkpoints as defined by the Plan
      proto.

  Returns:
    An instance of `plan_pb2.Plan` corresponding to the MapReduce form `mrf`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If any of the arguments are found to be in an unexpected form.
  """
  type_checks.check_type(mrf, tff.backends.mapreduce.MapReduceForm, name='mrf')
  client_plan_type = (
      ClientPlanType.TENSORFLOW
      if example_query_spec is None
      else ClientPlanType.EXAMPLE_QUERY
  )

  if example_query_spec is not None:
    if dataspec is not None:
      raise ValueError(
          '`example_query_spec` and `dataspec` cannot both be specified.'
      )

  with tff.framework.get_context_stack().install(
      tff.test.create_runtime_error_context()
  ):
    is_broadcast_empty = (
        isinstance(mrf.prepare.type_signature.result, tff.StructType)
        and not mrf.prepare.type_signature.result
    )
    if is_broadcast_empty:
      # This MapReduceForm does not send any server state to clients; however,
      # we need something to satisfy current restrictions from the FCP server,
      # so we use a placeholder scalar int.
      broadcast_tff_type = tff.TensorType(tf.int32)
    else:
      broadcast_tff_type = mrf.prepare.type_signature.result

    # Execute the bitwidths TFF computation using the default TFF executor.
    bitwidths, max_inputs, moduli = _compute_secagg_parameters(mrf)
    # Note: the variables below are flat lists, even though
    # `secure_sum_bitwidth_update_type` could potentially represent a large
    # group of nested tensors. In order for each var to line up with the
    # appropriate bitwidth, we must also flatten the list of bitwidths.
    flattened_bitwidths = tff.structure.flatten(bitwidths)
    flattened_max_inputs = tff.structure.flatten(max_inputs)
    flattened_moduli = tff.structure.flatten(moduli)

    (
        server_graph_def,
        server_savepoint,
        server_phase,
        broadcasted_tensor_specs,
    ) = _build_server_graph(
        mrf,
        broadcast_tff_type,
        is_broadcast_empty,
        flattened_bitwidths,
        flattened_max_inputs,
        flattened_moduli,
        write_metrics_to_checkpoint,
        additional_checkpoint_metadata_var_fn,
        experimental_client_update_format=experimental_client_checkpoint_write,
    )

    if client_plan_type == ClientPlanType.TENSORFLOW:
      client_graph_def, client_phase = _build_client_graph_with_tensorflow_spec(
          mrf.work,
          dataspec,
          broadcasted_tensor_specs,
          is_broadcast_empty,
          experimental_checkpoint_write=experimental_client_checkpoint_write,
      )
    elif client_plan_type == ClientPlanType.EXAMPLE_QUERY:
      client_phase = _build_client_phase_with_example_query_spec(
          mrf.work, example_query_spec
      )
    else:
      raise ValueError(
          f'Unexpected value for `client_plan_type`: {client_plan_type}'
      )

    combined_phases = plan_pb2.Plan.Phase(
        server_phase=server_phase, client_phase=client_phase
    )

    if generate_server_phase_v2:
      assert daf
      assert grappler_config
      (server_graph_def_prepare, server_graph_def_result, server_phase_v2) = (
          _build_server_graphs_from_distribute_aggregate_form(
              daf, is_broadcast_empty, grappler_config
          )
      )
      combined_phases.server_phase_v2.CopyFrom(server_phase_v2)

    plan = plan_pb2.Plan(
        version=1, server_savepoint=server_savepoint, phase=[combined_phases]
    )

    plan.server_graph_bytes.Pack(server_graph_def)
    if client_plan_type == ClientPlanType.TENSORFLOW:
      plan.client_graph_bytes.Pack(client_graph_def)

    if generate_server_phase_v2:
      plan.server_graph_prepare_bytes.Pack(server_graph_def_prepare)
      plan.server_graph_result_bytes.Pack(server_graph_def_result)
    return plan


def build_cross_round_aggregation_execution(
    mrf: tff.backends.mapreduce.MapReduceForm,
) -> bytes:
  """Constructs an instance of `plan_pb2.CrossRoundAggregationExecution`.

  Args:
    mrf: An instance of `tff.backends.mapreduce.MapReduceForm`.

  Returns:
    A serialized instance of `plan_pb2.CrossRoundAggregationExecution` for the
    given `mrf`.
  """
  type_checks.check_type(mrf, tff.backends.mapreduce.MapReduceForm, name='mrf')

  server_metrics_type = mrf.update.type_signature.result[1]
  (
      simpleagg_update_type,
      secure_sum_bitwidth_update_type,
      secure_sum_update_type,
      secure_modular_sum_update_type,
  ) = mrf.work.type_signature.result
  # We never work directly on `simpleagg_update_type` because client updates
  # are transformed by `accumulate` and `merge` before being passed into
  # cross-round aggregation.
  del simpleagg_update_type
  simpleagg_merge_type = mrf.merge.type_signature.result
  flattened_moduli = tff.structure.flatten(mrf.secure_modular_sum_modulus())

  if not server_metrics_type:
    # No metrics to aggregate; will initialize to a no-op.
    server_metrics_type = tff.StructType([])
  elif isinstance(server_metrics_type, tff.TensorType):
    # A single tensor metric must be wrapped inside of a NamedTuple for proper
    # variable initialization.
    server_metrics_type = tff.StructType([server_metrics_type])
  combined_aggregated_update_type = tff.StructType([
      simpleagg_merge_type,
      secure_sum_bitwidth_update_type,
      secure_sum_update_type,
      secure_modular_sum_update_type,
  ])

  with tf.Graph().as_default() as cross_round_aggregation_graph:
    server_state_vars = variable_helpers.create_vars_for_tff_type(
        mrf.update.type_signature.parameter[0],
        artifact_constants.SERVER_STATE_VAR_PREFIX,
    )

    combined_aggregated_update_vars, read_aggregated_update = (
        checkpoint_utils.create_state_vars_and_savepoint(
            combined_aggregated_update_type, 'aggregated_update'
        )
    )

    num_simpleagg_vars = len(tff.structure.flatten(simpleagg_merge_type))

    aggregated_update_vars = combined_aggregated_update_vars[
        :num_simpleagg_vars
    ]
    secagg_aggregated_update_vars = combined_aggregated_update_vars[
        num_simpleagg_vars:
    ]

    # Add a new output for metrics_loader `merge` and `report`.
    combined_final_accumulator_vars, read_write_final_accumulators = (
        checkpoint_utils.create_state_vars_and_savepoint(
            combined_aggregated_update_type, 'final_accumulators'
        )
    )

    final_accumulator_vars = combined_final_accumulator_vars[
        :num_simpleagg_vars
    ]
    secagg_final_accumulator_vars = combined_final_accumulator_vars[
        num_simpleagg_vars:
    ]

    var_init_op = tf.compat.v1.initializers.variables(
        server_state_vars
        + combined_aggregated_update_vars
        + combined_final_accumulator_vars
    )

    # Embeds the MapReduce form `merge` logic.
    merged_values = graph_helpers.import_tensorflow(
        'merge', mrf.merge, (final_accumulator_vars, aggregated_update_vars)
    )
    final_accumulator_assign_ops = tf.nest.map_structure(
        lambda variable, tensor: variable.assign(tensor),
        final_accumulator_vars,
        merged_values,
    )

    # SecAgg tensors' aggregation is not provided in the imported TensorFlow,
    # but is instead fixed based on the operator (e.g. `assign_add` for
    # variables passed into `secure_sum`).
    secagg_final_accumulator_ops = _merge_secagg_vars(
        secure_sum_bitwidth_update_type,
        secure_sum_update_type,
        secure_modular_sum_update_type,
        flattened_moduli,
        secagg_final_accumulator_vars,
        secagg_aggregated_update_vars,
    )
    final_accumulator_op = tf.group(
        *(final_accumulator_assign_ops + secagg_final_accumulator_ops)
    ).name

    # Embeds the `report` and `update` logic, and hooks up the assignments of
    # the results of the final update to the server state and metric vars, to
    # be triggered by `apply_aggregrated_updates_op`.
    simpleagg_reported_values = graph_helpers.import_tensorflow(
        'report', mrf.report, final_accumulator_vars
    )
    combined_final_vars = (
        simpleagg_reported_values + secagg_final_accumulator_vars
    )
    (_, server_metric_values) = graph_helpers.import_tensorflow(
        artifact_constants.UPDATE,
        mrf.update,
        (server_state_vars, combined_final_vars),
        split_outputs=True,
    )

    server_metrics_names = variable_helpers.variable_names_from_type(
        server_metrics_type, name=artifact_constants.SERVER_STATE_VAR_PREFIX
    )

    flattened_metrics_types = tff.structure.flatten(server_metrics_type)
    measurements = [
        proto_helpers.make_measurement(v, s, a)
        for v, s, a in zip(
            server_metric_values, server_metrics_names, flattened_metrics_types
        )
    ]

  cross_round_aggregation_execution = plan_pb2.CrossRoundAggregationExecution(
      init_op=var_init_op.name,
      read_aggregated_update=read_aggregated_update,
      merge_op=final_accumulator_op,
      read_write_final_accumulators=read_write_final_accumulators,
      measurements=measurements,
  )

  cross_round_aggregation_execution.cross_round_aggregation_graph_bytes.Pack(
      cross_round_aggregation_graph.as_graph_def()
  )

  return cross_round_aggregation_execution.SerializeToString()
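

# Illustrative usage sketch (not part of the library API). The commented-out
# snippet below is a minimal, hypothetical example of driving `build_plan` and
# `build_cross_round_aggregation_execution` for a TensorFlow-based client
# plan. It assumes the caller has already obtained a `MapReduceForm` (`mrf`)
# from the TFF compiler pipeline, for example via
# `tff.backends.mapreduce.get_map_reduce_form_for_computation`; how `mrf` is
# produced is outside the scope of this module, and that call is only an
# assumption about one way it might be done.
#
#   mrf = ...  # A `tff.backends.mapreduce.MapReduceForm` for the round logic.
#   plan = build_plan(
#       mrf,
#       dataspec=None,  # Example selectors must later be supplied through the
#                       # `constant_inputs` field of the `TensorflowSpec`.
#       write_metrics_to_checkpoint=True,
#   )
#   serialized_plan = plan.SerializeToString()
#   cross_round_agg = build_cross_round_aggregation_execution(mrf)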
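#
# A similar sketch for the lightweight example-query path (again hypothetical,
# not the authoritative usage): a caller-supplied `example_query_spec`, whose
# output vector names must match the names that
# `variable_helpers.variable_names_from_type` would derive from `mrf.work`,
# replaces the TF client graph entirely.
#
#   example_query_spec = plan_pb2.ExampleQuerySpec(...)  # Hypothetical spec.
#   lightweight_plan = build_plan(
#       mrf,
#       example_query_spec=example_query_spec,  # Mutually exclusive with
#                                               # `dataspec`.
#   )
#   # The resulting plan carries an `ExampleQuerySpec` and a
#   # `FederatedExampleQueryIORouter` instead of a client graph.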