# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper methods for TensorFlow variables."""

from typing import Optional, Union

import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import tensor_utils
from fcp.artifact_building import type_checks

# TFF types allowed for variables created at input/output serialization
# boundaries.
AllowedTffTypes = Union[tff.TensorType, tff.StructType, tff.FederatedType]


# The prefix for the name of the sidechannel for a securely-summed variable.
#
# This transformed name is used as the name of the Op which *reads* from the
# variable, rather than identifies the variable itself. Names with this prefix
# are used as the keys in the `side_channel_tensors` map entries corresponding
# to the variable of the unprefixed name.
SIDECHANNEL_NAME_PREFIX = 'sidechannel_'

# `variable_names_from_type` returns the `name` argument of `tf.Variable()`.
# However, when the variable is created, the name of its tensor is actually
# `<name>:0`. This constant is defined to match that behavior.
_TF_TENSOR_NAME_SUFFIX = ':0'


def _create_var_for_tff_tensor(
    tff_type: tff.TensorType, name: str, **kwargs
) -> tf.Variable:
  """Creates a TensorFlow variable to hold a value of the `tff.TensorType`."""
  type_checks.check_type(tff_type, tff.TensorType)
  type_checks.check_type(name, str)
  # `tff_type` can have shapes that contain `None` or `0`:
  # * `None` shape cannot be used in `tf.zeros` to create the initial value
  #   of a `tf.Variable`. Hence, we replace it with a `0` in `tf.zeros`.
  # * The dimension that has `0` shape may change its shape at run time. To
  #   support this, we use `None` for that dimension when creating the
  #   `tf.Variable`.
  initial_value_shape = []
  variable_shape = []
  for shape in tff_type.shape.as_list():
    if shape is None or shape == 0:
      initial_value_shape.append(0)
      variable_shape.append(None)
    else:
      initial_value_shape.append(shape)
      variable_shape.append(shape)
  return tf.Variable(
      initial_value=tf.zeros(shape=initial_value_shape, dtype=tff_type.dtype),
      name=name,
      dtype=tff_type.dtype,
      shape=variable_shape,
      **kwargs,
  )


# Build the TensorSpec for the values we will send to the client so that the
# client graph will know how to read the incoming values.
def tensorspec_from_var(var: tf.Variable) -> tf.TensorSpec:
  """Builds a `tf.TensorSpec` from a `tf.Variable`.

  Args:
    var: An instance of `tf.Variable`.

  Returns:
    An instance of `tf.TensorSpec` corresponding to the input `tf.Variable`.
  """
  return tf.TensorSpec(
      shape=var.shape, dtype=var.dtype, name=tensor_utils.bare_name(var.name)
  )


def create_vars_for_tff_type(
    tff_type: AllowedTffTypes, name: Optional[str] = None, **kwargs
) -> list[tf.Variable]:
  """Creates TensorFlow variables to hold a value of the given `tff_type`.

  The variables are created in the default graph and scope. The variables are
  automatically given `tf.zeros` initializers.

  Args:
    tff_type: Either a `tff.StructType`, SERVER-placed `tff.FederatedType` or a
      `tff.TensorType` object.
    name: The preferred name to use at the top-most level (if not None, must be
      a string). If `tff_type` is a `tff.StructType`, the names of the inner
      fields will be scoped under `name`, e.g. `some_name/field_name`.
    **kwargs: Optional arguments, if any, to pass to the `tf.Variable()` calls.

  Returns:
    A flat Python `list` of TensorFlow variable instances.

  Raises:
    TypeError: If the argument is of the wrong type or has the wrong placement.
  """
  type_checks.check_type(
      tff_type,
      (tff.TensorType, tff.StructType, tff.FederatedType),
      name='tff_type',
  )
  if name is not None:
    type_checks.check_type(name, str)
  else:
    name = 'v'
  if isinstance(tff_type, tff.TensorType):
    return [_create_var_for_tff_tensor(tff_type, name, **kwargs)]
  elif isinstance(tff_type, tff.FederatedType):
    if tff_type.placement != tff.SERVER:
      raise TypeError(
          'Can only create vars for unplaced types or types placed '
          'on the SERVER.'
      )
    return create_vars_for_tff_type(tff_type.member, name, **kwargs)
  else:  # tff.StructType
    result = []
    with tf.compat.v1.variable_scope(name):
      fields = tff.structure.to_elements(tff_type)
      for index, (field_name, field_type) in enumerate(fields):
        # Default the name of the element to its index so that we don't wind up
        # with multiple child fields listed under `/v/`
        if field_name is None:
          field_name = str(index)
        result.extend(
            create_vars_for_tff_type(field_type, name=field_name, **kwargs)
        )
    return result


def variable_names_from_type(
    tff_type: AllowedTffTypes, name: str = 'v'
) -> list[str]:
  """Creates a flattened list of variable names for the given `tff_type`.

  If `tff_type` is a `tff.TensorType`, the name is the `name` parameter if
  specified, otherwise a default name: `v`.
  If `tff_type` is a `tff.StructType` then '/' is used between inner and outer
  fields together with the tuple name or index of the element in the tuple.

  Some examples:
  1. If the tff_type is `<'a'=tf.int32, 'b'=tf.int32>` and `name` is not
     specified, the returned variable name list is ['v/a', 'v/b'].
  2. If the tff_type is `<tf.int32, tf.int32>` and `name` is `update`, the
     returned variable name list is ['update/0', 'update/1'].
  3. If the tff_type is `<'a'=<'b'=tf.int32, 'c'=tf.int32>>` and `name` is
     `update`, the returned variable name list is ['update/a/b', 'update/a/c'].
  4. If the tff_type is `<'a'=<'b'=tf.int32, 'c'=tf.int32, tf.int32>>` and
     `name` is `update`, the returned variable name list is ['update/a/b',
     'update/a/c', 'update/a/2'].

  Args:
    tff_type: Either a `tff.StructType`, a `tff.FederatedType` or a
      `tff.TensorType` object.
    name: The preferred name to use at the top-most level (if not None, must be
      a string). If `tff_type` is a `tff.StructType`, the names of the inner
      fields will be scoped under `name`, e.g. `some_name/field_name`.

  Returns:
    A flat Python `list` of `str` names.

  Raises:
    TypeError: If the argument is of the wrong type.
  """
  type_checks.check_type(
      tff_type,
      (tff.TensorType, tff.FederatedType, tff.StructType),
      name='tff_type',
  )
  type_checks.check_type(name, str, name='name')
  if isinstance(tff_type, tff.TensorType):
    return [name]
  elif isinstance(tff_type, tff.FederatedType):
    return variable_names_from_type(tff_type.member, name)
  elif isinstance(tff_type, tff.StructType):
    result = []
    fields = tff.structure.iter_elements(tff_type)
    for index, (field_name, field_type) in enumerate(fields):
      # Default the name of the element to its index so that we don't wind up
      # with multiple child fields listed under `/v/`
      field_name = field_name or str(index)
      result.extend(
          variable_names_from_type(field_type, name=name + '/' + field_name)
      )
    return result
  else:
    raise TypeError(
        'Cannot create variable names from [{t}] TFF type. '
        'Short-hand: {s}'.format(t=type(tff_type), s=tff_type)
    )


def get_shared_secagg_tensor_names(
    intrinsic_name: str, tff_type: AllowedTffTypes
) -> list[str]:
  """Creates the shared names of secagg tensors in the client and server graphs.

  This is the canonical function for ensuring that the secagg tensor names in
  the client and server graphs are the same. The server uses secagg tensor names
  as the keys to retrieve values from the secagg server, and those values
  originate in the client graph. If the secagg tensor names in the client and
  server graphs differ, the server cannot find the secagg tensors. This function
  exists to enforce this implicit dependency.

  Args:
    intrinsic_name: The name of the secure aggregation intrinsic being used.
    tff_type: Either a `tff.StructType`, `tff.FederatedType` or a
      `tff.TensorType` object.

  Returns:
    A list of variable names created from the input TFF type.
  """
  tensor_names = variable_names_from_type(
      tff_type, f'secagg_{intrinsic_name}_update'
  )
  return [
      SIDECHANNEL_NAME_PREFIX + name + _TF_TENSOR_NAME_SUFFIX
      for name in tensor_names
  ]


def get_flattened_tensor_specs(
    tff_type: AllowedTffTypes, name: str
) -> list[tf.TensorSpec]:
  """Generates TensorSpecs for a flattened version of the given `tff_type`.

  This function uses the same naming logic as `variable_names_from_type`. Please
  see that function's docstring.

  Args:
    tff_type: Either a `tff.StructType`, a `tff.FederatedType` or a
      `tff.TensorType` object.
    name: The preferred name to use at the top-most level (if not None, must be
      a string). If `tff_type` is a `tff.StructType`, the names of the inner
      fields will be scoped under `name`, e.g. `some_name/field_name`.

  Returns:
    A flat Python `list` of `TensorSpec`s.

  Raises:
    TypeError: If the argument is of the wrong type.
  """
  type_checks.check_type(
      tff_type,
      (tff.TensorType, tff.FederatedType, tff.StructType),
      name='tff_type',
  )
  type_checks.check_type(name, str, name='name')
  if isinstance(tff_type, tff.TensorType):
    return [tf.TensorSpec(tff_type.shape, tff_type.dtype, name=name)]
  elif isinstance(tff_type, tff.FederatedType):
    return get_flattened_tensor_specs(tff_type.member, name)
  elif isinstance(tff_type, tff.StructType):
    result = []
    fields = tff.structure.iter_elements(tff_type)
    for index, (field_name, field_type) in enumerate(fields):
      # Default the name of the
      # element to its index so that we don't wind up
      # with multiple child fields listed under `/v/`
      field_name = field_name or str(index)
      result.extend(
          get_flattened_tensor_specs(field_type, name=name + '/' + field_name)
      )
    return result
  else:
    raise TypeError(
        'Cannot create TensorSpecs from [{t}] TFF type. Short-hand: {s}'.format(
            t=type(tff_type), s=tff_type
        )
    )


def get_grouped_input_tensor_specs_for_aggregations(
    aggregation_comp: tff.framework.ComputationBuildingBlock,
    names: dict[int, str],
) -> list[list[list[tf.TensorSpec]]]:
  """Gets the input TensorSpecs for an aggregation computation.

  This function can be used to generate the TensorSpecs that are assigned to
  ServerAggregationConfig.IntrinsicArg messages to represent the aggregation
  intrinsic calls in DistributeAggregateForm.client_to_server_aggregation.

  It derives the tensor name(s) for each intrinsic input argument by following
  naming logic similar to `variable_names_from_type`. DistributeAggregateForm
  guarantees that each intrinsic input argument will be a
  `building_block.Selection` or a (potentially nested) struct of
  `building_block.Selection`s. The first element of the selection path is used
  to determine the top-level name, which must match the top-level name that was
  used to construct the tensor that will be consumed by this argument.

  Args:
    aggregation_comp: The aggregation computation.
    names: A dictionary describing how to map the first element of the path to a
      top-level name.

  Returns:
    A `list` where the ith entry represents the input tensor specs for the
    ith intrinsic in the aggregation computation. The ith entry is itself a list
    where the jth entry represents the input tensor specs for the jth argument
    of the ith intrinsic in the aggregation computation.

  Raises:
    TypeError: If the argument is of the wrong type.
    ValueError: If the argument contains an unexpected
      `building_block.Selection` index.
  """

  def _get_selection_path(
      selection: tff.framework.ComputationBuildingBlock,
  ) -> list[int]:
    """Gets the list of selection indices for a building_blocks.Selection."""

    path = []
    while selection.is_selection():
      path.append(selection.index)  # pytype: disable=attribute-error
      selection = selection.source  # pytype: disable=attribute-error
    # In ASTs like x[0][1], we'll see the last (outermost) selection first.
    path.reverse()
    return path

  def _get_input_tensor_specs_for_aggregation_arg(
      value: tff.framework.ComputationBuildingBlock, names: dict[int, str]
  ) -> list[tf.TensorSpec]:
    """Gets the input TensorSpecs for a single intrinsic argument."""

    # An intrinsic arg may be a `building_block.Selection` or a (potentially
    # nested) struct of `building_block.Selection`s.
    # Start by creating a flattened list of the `building_block.Selection`s.
    inner_values = []
    if value.is_struct():
      inner_values = tff.structure.flatten(value)
    else:
      inner_values = [value]

    # For each `building_block.Selection`, reconstruct the tensor name that
    # will be used to supply that value. The first index of the selection path
    # indicates whether the tensor will be coming from the intermediate state
    # checkpoint (0) or from the client checkpoint (1), since TFF condenses
    # daf.client_to_server_aggregation(temp_server_state, client_update)
    # into a 1-arg function. Since the tensors within the checkpoints
    # corresponding to temp_server_state and work_at_clients will be named using
    # variable_names_from_type, which uses a simple filepath-like naming pattern
    # to refer to the tensors within a struct, we can reconstruct the relevant
    # tensor name by concatenating together the remaining indices of each
    # selection path.
    tensor_specs = []
    for inner_value in inner_values:
      inner_value.check_selection()
      path = _get_selection_path(inner_value)
      arg_index = path[0]
      if arg_index in names:
        prefix = names[arg_index]
      else:
        raise ValueError('Unexpected arg index for aggregation selection')
      prefix += '/' + '/'.join([str(x) for x in path[1:]])
      tensor_specs.extend(
          get_flattened_tensor_specs(inner_value.type_signature, name=prefix)
      )

    return tensor_specs

  grouped_input_tensor_specs = []

  for _, local_value in aggregation_comp.result.locals:  # pytype: disable=attribute-error
    local_value.check_call()
    local_value.function.check_intrinsic()
    assert local_value.function.intrinsic_def().aggregation_kind

    # Collect the input TensorSpecs for each argument for this intrinsic.
    input_tensor_specs_for_intrinsic = []
    if (
        local_value.function.intrinsic_def().type_signature.parameter.is_struct()
    ):
      for element in local_value.argument.children():
        input_tensor_specs_for_intrinsic.append(
            _get_input_tensor_specs_for_aggregation_arg(element, names)
        )
    else:
      input_tensor_specs_for_intrinsic.append(
          _get_input_tensor_specs_for_aggregation_arg(
              local_value.argument, names
          )
      )

    grouped_input_tensor_specs.append(input_tensor_specs_for_intrinsic)

  return grouped_input_tensor_specs


def get_grouped_output_tensor_specs_for_aggregations(
    aggregation_comp: tff.framework.ComputationBuildingBlock,
) -> list[list[tf.TensorSpec]]:
  """Gets the output TensorSpecs for an aggregation computation.
406*14675a02SAndroid Build Coastguard Worker 407*14675a02SAndroid Build Coastguard Worker This function can be used to generate the TensorSpecs that are assigned 408*14675a02SAndroid Build Coastguard Worker to the output_tensors value in ServerAggregationConfig messages to represent 409*14675a02SAndroid Build Coastguard Worker the aggregation intrinsic calls in 410*14675a02SAndroid Build Coastguard Worker DistributeAggregateForm.client_to_server_aggregation. 411*14675a02SAndroid Build Coastguard Worker 412*14675a02SAndroid Build Coastguard Worker It derives the tensor name(s) for each intrinsic output argument by following 413*14675a02SAndroid Build Coastguard Worker naming logic similar to `variable_names_from_type`. It must produce tensor 414*14675a02SAndroid Build Coastguard Worker names that match the tensor names that are expected by the post-aggregation 415*14675a02SAndroid Build Coastguard Worker computation. 416*14675a02SAndroid Build Coastguard Worker 417*14675a02SAndroid Build Coastguard Worker Args: 418*14675a02SAndroid Build Coastguard Worker aggregation_comp: The aggregation computation. 419*14675a02SAndroid Build Coastguard Worker 420*14675a02SAndroid Build Coastguard Worker Returns: 421*14675a02SAndroid Build Coastguard Worker A list where the ith entry represents the output tensor specs for the ith 422*14675a02SAndroid Build Coastguard Worker intrinsic in the aggregation computation. 423*14675a02SAndroid Build Coastguard Worker 424*14675a02SAndroid Build Coastguard Worker Raises: 425*14675a02SAndroid Build Coastguard Worker TypeError: If the argument is of the wrong type. 426*14675a02SAndroid Build Coastguard Worker """ 427*14675a02SAndroid Build Coastguard Worker # TensorflowSpecs for all the intrinsic results. 
These TensorflowSpecs must 428*14675a02SAndroid Build Coastguard Worker # have names that mirror the result of calling variable_names_from_type on 429*14675a02SAndroid Build Coastguard Worker # the output type of DistributeAggregateForm.client_to_server_aggregation 430*14675a02SAndroid Build Coastguard Worker # (which is the same as the type of the aggregation result input arg in 431*14675a02SAndroid Build Coastguard Worker # DistributeAggregateForm.server_result). 432*14675a02SAndroid Build Coastguard Worker output_tensor_specs = get_flattened_tensor_specs( 433*14675a02SAndroid Build Coastguard Worker tff.StructType([aggregation_comp.type_signature.result]), 434*14675a02SAndroid Build Coastguard Worker name='intermediate_update', 435*14675a02SAndroid Build Coastguard Worker ) 436*14675a02SAndroid Build Coastguard Worker output_tensor_spec_index = 0 437*14675a02SAndroid Build Coastguard Worker 438*14675a02SAndroid Build Coastguard Worker grouped_output_tensor_specs = [] 439*14675a02SAndroid Build Coastguard Worker 440*14675a02SAndroid Build Coastguard Worker for _, local_value in aggregation_comp.result.locals: # pytype: disable=attribute-error 441*14675a02SAndroid Build Coastguard Worker local_value.check_call() 442*14675a02SAndroid Build Coastguard Worker local_value.function.check_intrinsic() 443*14675a02SAndroid Build Coastguard Worker local_value.type_signature.check_federated() 444*14675a02SAndroid Build Coastguard Worker assert local_value.function.intrinsic_def().aggregation_kind 445*14675a02SAndroid Build Coastguard Worker 446*14675a02SAndroid Build Coastguard Worker tensor_specs = [] 447*14675a02SAndroid Build Coastguard Worker # If the output is a struct, select the appropriate number of 448*14675a02SAndroid Build Coastguard Worker # TensorflowSpecs. 
449*14675a02SAndroid Build Coastguard Worker if local_value.type_signature.member.is_struct(): 450*14675a02SAndroid Build Coastguard Worker num_specs = len(tff.structure.flatten(local_value.type_signature.member)) 451*14675a02SAndroid Build Coastguard Worker tensor_specs = output_tensor_specs[ 452*14675a02SAndroid Build Coastguard Worker output_tensor_spec_index : output_tensor_spec_index + num_specs 453*14675a02SAndroid Build Coastguard Worker ] 454*14675a02SAndroid Build Coastguard Worker output_tensor_spec_index += num_specs 455*14675a02SAndroid Build Coastguard Worker else: 456*14675a02SAndroid Build Coastguard Worker tensor_specs.append(output_tensor_specs[output_tensor_spec_index]) 457*14675a02SAndroid Build Coastguard Worker output_tensor_spec_index += 1 458*14675a02SAndroid Build Coastguard Worker grouped_output_tensor_specs.append(tensor_specs) 459*14675a02SAndroid Build Coastguard Worker 460*14675a02SAndroid Build Coastguard Worker return grouped_output_tensor_specs 461
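

# Illustrative sketch only; not part of the original module. The grouping step
# in `get_grouped_output_tensor_specs_for_aggregations` amounts to cutting one
# flat, ordered list into consecutive chunks, one chunk per intrinsic call.
# The helper name `_sketch_group_consecutive` and the use of plain strings in
# place of `tf.TensorSpec` objects are assumptions made for illustration; the
# real function derives each chunk size from the intrinsic's result type.
def _sketch_group_consecutive(flat_items, counts):
  """Splits `flat_items` into consecutive groups with the given sizes."""
  groups = []
  index = 0
  for count in counts:
    # Take the next `count` items, then advance past them.
    groups.append(list(flat_items[index : index + count]))
    index += count
  return groups


# Hypothetical usage: two intrinsics, the first producing a two-tensor struct
# and the second a single tensor.
#   _sketch_group_consecutive(['a', 'b', 'c'], [2, 1])
# yields [['a', 'b'], ['c']].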