xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/variable_helpers.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Helper methods for TensorFlow variables."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerfrom typing import Optional, Union
17*14675a02SAndroid Build Coastguard Worker
18*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
19*14675a02SAndroid Build Coastguard Workerimport tensorflow_federated as tff
20*14675a02SAndroid Build Coastguard Worker
21*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import tensor_utils
22*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import type_checks
23*14675a02SAndroid Build Coastguard Worker
# TFF types allowed for variables created at input/output serialization
# boundaries.
AllowedTffTypes = Union[tff.TensorType, tff.StructType, tff.FederatedType]


# The prefix for the name of the sidechannel for a securely-summed variable.
#
# This transformed name is used as the name of the Op which *reads* from the
# variable, rather than the name that identifies the variable itself. Names
# with this prefix are used as the keys in the `side_channel_tensors` map
# entries corresponding to the variable with the unprefixed name.
SIDECHANNEL_NAME_PREFIX = 'sidechannel_'

# `variable_names_from_type` returns the `name` argument of `tf.Variable()`.
# However, when the variable is created, the name of its tensor is actually
# `<name>:0`. This constant is appended to such names so that they match the
# tensor name.
_TF_TENSOR_NAME_SUFFIX = ':0'
41*14675a02SAndroid Build Coastguard Worker
42*14675a02SAndroid Build Coastguard Worker
def _create_var_for_tff_tensor(
    tff_type: tff.TensorType, name: str, **kwargs
) -> tf.Variable:
  """Builds a zero-initialized `tf.Variable` matching a `tff.TensorType`.

  Every dimension of `tff_type` that is unknown (`None`) or empty (`0`) is
  treated as dynamic:
  * the initial value uses `0` for it, because `tf.zeros` cannot take `None`;
  * the variable is declared with `None` for it, so that the dimension may
    change size at run time.
  All other dimensions are used unchanged for both.

  Args:
    tff_type: The `tff.TensorType` giving the dtype and shape of the value.
    name: The name handed to `tf.Variable`.
    **kwargs: Extra keyword arguments forwarded to `tf.Variable`.

  Returns:
    The newly created `tf.Variable`.

  Raises:
    TypeError: If `tff_type` or `name` has the wrong type (as reported by
      `type_checks.check_type`).
  """
  type_checks.check_type(tff_type, tff.TensorType)
  type_checks.check_type(name, str)

  dims = tff_type.shape.as_list()
  dynamic_mask = [dim is None or dim == 0 for dim in dims]
  zeros_shape = [
      0 if is_dynamic else dim for dim, is_dynamic in zip(dims, dynamic_mask)
  ]
  declared_shape = [
      None if is_dynamic else dim
      for dim, is_dynamic in zip(dims, dynamic_mask)
  ]

  dtype = tff_type.dtype
  initial_zeros = tf.zeros(shape=zeros_shape, dtype=dtype)
  return tf.Variable(
      initial_value=initial_zeros,
      name=name,
      dtype=dtype,
      shape=declared_shape,
      **kwargs,
  )
71*14675a02SAndroid Build Coastguard Worker
72*14675a02SAndroid Build Coastguard Worker
# Used to describe the values sent to the client, so that the client graph
# knows how to read the incoming values.
def tensorspec_from_var(var: tf.Variable) -> tf.TensorSpec:
  """Returns a `tf.TensorSpec` describing `var`.

  The spec copies the variable's shape and dtype. Its name is the variable's
  name as normalized by `tensor_utils.bare_name` (presumably stripping the
  `:<index>` tensor-output suffix; confirm against that helper).

  Args:
    var: The `tf.Variable` to describe.

  Returns:
    A `tf.TensorSpec` with the same shape and dtype as `var`.
  """
  spec_name = tensor_utils.bare_name(var.name)
  return tf.TensorSpec(shape=var.shape, dtype=var.dtype, name=spec_name)
87*14675a02SAndroid Build Coastguard Worker
88*14675a02SAndroid Build Coastguard Worker
def create_vars_for_tff_type(
    tff_type: AllowedTffTypes, name: Optional[str] = None, **kwargs
) -> list[tf.Variable]:
  """Creates TensorFlow variables able to hold a value of type `tff_type`.

  Variables are created in the default graph and scope, each initialized with
  zeros. The type is mapped to variables as follows:
  * a `tff.TensorType` becomes a single variable named `name`;
  * a SERVER-placed `tff.FederatedType` is handled as its member type;
  * a `tff.StructType` is handled element by element inside a
    `tf.compat.v1.variable_scope(name)`, each element using its own name, or
    its positional index when unnamed, e.g. `some_name/field_name`.

  Args:
    tff_type: A `tff.TensorType`, a `tff.StructType`, or a SERVER-placed
      `tff.FederatedType`.
    name: Optional top-level name. It must be a `str` when given, and
      defaults to `'v'`.
    **kwargs: Extra keyword arguments forwarded to every `tf.Variable()` call.

  Returns:
    A flat Python `list` of the created `tf.Variable`s, in element order.

  Raises:
    TypeError: If an argument has the wrong type, or if a federated type is
      not placed on the SERVER.
  """
  type_checks.check_type(
      tff_type,
      (tff.TensorType, tff.StructType, tff.FederatedType),
      name='tff_type',
  )
  if name is None:
    name = 'v'
  else:
    type_checks.check_type(name, str)

  if isinstance(tff_type, tff.FederatedType):
    if tff_type.placement != tff.SERVER:
      raise TypeError(
          'Can only create vars for unplaced types or types placed '
          'on the SERVER.'
      )
    # Placement is only validated; the variables hold the member value.
    return create_vars_for_tff_type(tff_type.member, name, **kwargs)

  if isinstance(tff_type, tff.TensorType):
    return [_create_var_for_tff_tensor(tff_type, name, **kwargs)]

  # Only tff.StructType remains at this point.
  created = []
  with tf.compat.v1.variable_scope(name):
    elements = tff.structure.to_elements(tff_type)
    for position, (element_name, element_type) in enumerate(elements):
      # Unnamed elements are named after their position. Otherwise they would
      # all fall back to the default name `v` and collide within the scope.
      if element_name is None:
        element_name = str(position)
      created.extend(
          create_vars_for_tff_type(
              element_type, name=element_name, **kwargs
          )
      )
  return created
142*14675a02SAndroid Build Coastguard Worker
143*14675a02SAndroid Build Coastguard Worker
def variable_names_from_type(
    tff_type: AllowedTffTypes, name: str = 'v'
) -> list[str]:
  """Creates a flattened list of variable names for the given `tff_type`.

  A `tff.TensorType` yields just `name`. A `tff.FederatedType` is named after
  its member type; placement is ignored. A `tff.StructType` recurses into
  each element, appending '/' plus the element's name, or its positional
  index when the element is unnamed, to the running name.

  Some examples:
  1. If the tff_type is `<'a'=tf.int32, 'b'=tf.int32>` and `name` is not
    specified, the returned variable name list is ['v/a', 'v/b'].
  2. If the tff_type is `<tf.int32, tf.int32>` and `name` is `update`, the
    returned variable name list is ['update/0', 'update/1'].
  3. If the tff_type is `<'a'=<'b'=tf.int32, 'c'=tf.int32>>` and `name` is
    `update`, the returned variable name list is ['update/a/b', 'update/a/c'].
  4. If the tff_type is `<'a'=<'b'=tf.int32, 'c'=tf.int32, tf.int32>>` and
    `name` is `update`, the returned variable name list is ['update/a/b',
    'update/a/c', 'update/a/2'].

  Args:
    tff_type: A `tff.TensorType`, `tff.FederatedType` or `tff.StructType`.
    name: The top-level name. It must be a `str` (`None` is rejected) and
      defaults to `'v'`.

  Returns:
    A flat Python `list` of `str` names, one per tensor leaf, in element
    order.

  Raises:
    TypeError: If the argument is of the wrong type.
  """
  type_checks.check_type(
      tff_type,
      (tff.TensorType, tff.FederatedType, tff.StructType),
      name='tff_type',
  )
  type_checks.check_type(name, str, name='name')

  if isinstance(tff_type, tff.TensorType):
    return [name]
  if isinstance(tff_type, tff.FederatedType):
    return variable_names_from_type(tff_type.member, name)
  if not isinstance(tff_type, tff.StructType):
    raise TypeError(
        'Cannot create variable names from [{t}] TFF type. '
        'Short-hand: {s}'.format(t=type(tff_type), s=tff_type)
    )

  # Unnamed elements use their index so that siblings get distinct names.
  return [
      leaf_name
      for position, (element_name, element_type) in enumerate(
          tff.structure.iter_elements(tff_type)
      )
      for leaf_name in variable_names_from_type(
          element_type, name=name + '/' + (element_name or str(position))
      )
  ]
204*14675a02SAndroid Build Coastguard Worker
205*14675a02SAndroid Build Coastguard Worker
def get_shared_secagg_tensor_names(
    intrinsic_name: str, tff_type: AllowedTffTypes
) -> list[str]:
  """Returns the secagg tensor names shared by the client and server graphs.

  This is the single source of truth for secagg tensor names. The server
  looks up values coming from the secagg server, which originate in the
  client graph, using these names as keys. Both graphs must therefore derive
  the names the same way, and this function centralizes that implicit
  dependency.

  Each name is built from `variable_names_from_type(tff_type,
  'secagg_<intrinsic_name>_update')`, prefixed with `SIDECHANNEL_NAME_PREFIX`
  and suffixed with `_TF_TENSOR_NAME_SUFFIX`. For example, with
  `intrinsic_name='x'` and `tff_type=<'a'=tf.int32>` the result is
  `['sidechannel_secagg_x_update/a:0']`.

  Args:
    intrinsic_name: The name of the secure aggregation intrinsic being used.
    tff_type: A `tff.StructType`, `tff.FederatedType` or `tff.TensorType`.

  Returns:
    A list of tensor names derived from the input TFF type.
  """
  base_name = f'secagg_{intrinsic_name}_update'
  variable_names = variable_names_from_type(tff_type, base_name)
  return [
      ''.join((SIDECHANNEL_NAME_PREFIX, variable_name, _TF_TENSOR_NAME_SUFFIX))
      for variable_name in variable_names
  ]
233*14675a02SAndroid Build Coastguard Worker
234*14675a02SAndroid Build Coastguard Worker
def get_flattened_tensor_specs(
    tff_type: AllowedTffTypes, name: str
) -> list[tf.TensorSpec]:
  """Flattens `tff_type` into a list of named `tf.TensorSpec`s.

  Spec names follow the same scheme as `variable_names_from_type`. The name
  starts with `name`, and each enclosing `tff.StructType` level appends '/'
  plus the element's name, or its positional index when unnamed. A
  `tff.FederatedType` adds no component; its member is flattened under the
  same name. Each `tff.TensorType` leaf becomes one spec carrying the leaf's
  shape and dtype.

  Args:
    tff_type: A `tff.StructType`, `tff.FederatedType` or `tff.TensorType`.
    name: The top-level name. It must be a `str` (`None` is rejected).

  Returns:
    A flat Python `list` of `tf.TensorSpec`s, one per tensor leaf, in
    element order.

  Raises:
    TypeError: If the argument is of the wrong type.
  """
  type_checks.check_type(
      tff_type,
      (tff.TensorType, tff.FederatedType, tff.StructType),
      name='tff_type',
  )
  type_checks.check_type(name, str, name='name')

  if isinstance(tff_type, tff.StructType):
    specs = []
    for position, (element_name, element_type) in enumerate(
        tff.structure.iter_elements(tff_type)
    ):
      # Unnamed elements use their index so that siblings get distinct names.
      child_name = name + '/' + (element_name or str(position))
      specs.extend(get_flattened_tensor_specs(element_type, name=child_name))
    return specs
  if isinstance(tff_type, tff.FederatedType):
    return get_flattened_tensor_specs(tff_type.member, name)
  if isinstance(tff_type, tff.TensorType):
    return [tf.TensorSpec(tff_type.shape, tff_type.dtype, name=name)]
  raise TypeError(
      'Cannot create TensorSpecs from [{t}] TFF type. Short-hand: {s}'.format(
          t=type(tff_type), s=tff_type
      )
  )
283*14675a02SAndroid Build Coastguard Worker
284*14675a02SAndroid Build Coastguard Worker
def get_grouped_input_tensor_specs_for_aggregations(
    aggregation_comp: tff.framework.ComputationBuildingBlock,
    names: dict[int, str],
) -> list[list[list[tf.TensorSpec]]]:
  """Gets the input TensorSpecs for an aggregation computation.

  This function can be used to generate the TensorSpecs that are assigned to
  ServerAggregationConfig.IntrinsicArg messages to represent the aggregation
  intrinsic calls in DistributeAggregateForm.client_to_server_aggregation.

  It derives the tensor name(s) for each intrinsic input argument by following
  naming logic similar to `variable_names_from_type`. DistributeAggregateForm
  guarantees that each intrinsic input argument will be a
  `building_block.Selection` or a (potentially nested) struct of
  `building_block.Selection`s.
  * The first element of each selection path is mapped through `names` to a
    top-level name. That name must match the top-level name used to construct
    the tensor consumed by this argument.
  * The remaining path indices are appended as '/'-separated components. For
    example, the selection `arg[1][0][2]` with `names={1: 'update'}` uses the
    name prefix `update/0/2`.
  * A selection with no indices beyond the first (e.g. `arg[1]`) uses the
    top-level name alone, with no trailing '/'.

  Args:
    aggregation_comp: The aggregation computation. Its `result` must expose
      `locals` as `(name, value)` pairs, each value being a call to an
      aggregation intrinsic (presumably a Lambda whose result is a Block).
    names: A dictionary describing how to map the first element of the path to
      a top-level name.

  Returns:
    A `list` where the ith entry represents the input tensor specs for the
    ith intrinsic in the aggregation computation. The ith entry is itself a list
    where the jth entry represents the input tensor specs for the jth argument
    of the ith intrinsic in the aggregation computation.

  Raises:
    TypeError: If the argument is of the wrong type.
    ValueError: If the argument contains an unexpected
      `building_block.Selection` index.
  """

  def _get_selection_path(
      selection: tff.framework.ComputationBuildingBlock,
  ) -> list[int]:
    """Returns the selection indices of a chain of `Selection`s.

    Indices are ordered from the innermost (first-applied) selection to the
    outermost one, e.g. `x[0][1]` yields `[0, 1]`.
    """
    path = []
    while selection.is_selection():
      path.append(selection.index)  # pytype: disable=attribute-error
      selection = selection.source  # pytype: disable=attribute-error
    # In ASTs like x[0][1], we'll see the last (outermost) selection first.
    path.reverse()
    return path

  def _get_input_tensor_specs_for_aggregation_arg(
      value: tff.framework.ComputationBuildingBlock, names: dict[int, str]
  ) -> list[tf.TensorSpec]:
    """Gets the input TensorSpecs for a single intrinsic argument."""

    # An intrinsic arg may be a `building_block.Selection` or a (potentially
    # nested) struct of `building_block.Selection`s. Start by creating a
    # flattened list of the `building_block.Selection`s.
    if value.is_struct():
      inner_values = tff.structure.flatten(value)
    else:
      inner_values = [value]

    # For each `building_block.Selection`, reconstruct the tensor name that
    # will be used to supply that value. The first index of the selection path
    # indicates whether the tensor will be coming from the intermediate state
    # checkpoint (0) or from the client checkpoint (1), since TFF condenses
    # daf.client_to_server_aggregation(temp_server_state, client_update)
    # into a 1-arg function. The tensors within the checkpoints corresponding
    # to temp_server_state and client_update are named using
    # variable_names_from_type, which uses a simple filepath-like naming
    # pattern for the tensors within a struct. The relevant tensor name can
    # therefore be rebuilt by joining the remaining indices of each selection
    # path.
    tensor_specs = []
    for inner_value in inner_values:
      inner_value.check_selection()
      path = _get_selection_path(inner_value)
      arg_index = path[0]
      if arg_index not in names:
        raise ValueError('Unexpected arg index for aggregation selection')
      # Join every component at once. Appending '/' + join(rest) would leave a
      # dangling '/' when the path has no indices beyond the first, producing
      # names that never match `variable_names_from_type`.
      prefix = '/'.join([names[arg_index], *(str(x) for x in path[1:])])
      tensor_specs.extend(
          get_flattened_tensor_specs(inner_value.type_signature, name=prefix)
      )

    return tensor_specs

  grouped_input_tensor_specs = []

  for _, local_value in aggregation_comp.result.locals:  # pytype: disable=attribute-error
    local_value.check_call()
    local_value.function.check_intrinsic()
    intrinsic_def = local_value.function.intrinsic_def()
    # Internal invariant: DistributeAggregateForm only calls aggregation
    # intrinsics here.
    assert intrinsic_def.aggregation_kind

    # Collect the input TensorSpecs for each argument of this intrinsic. When
    # the intrinsic's parameter type is a struct, each child of the call
    # argument is treated as a separate argument. Otherwise the whole
    # argument is treated as one argument.
    if intrinsic_def.type_signature.parameter.is_struct():
      arg_values = local_value.argument.children()
    else:
      arg_values = [local_value.argument]

    grouped_input_tensor_specs.append([
        _get_input_tensor_specs_for_aggregation_arg(arg_value, names)
        for arg_value in arg_values
    ])

  return grouped_input_tensor_specs
400*14675a02SAndroid Build Coastguard Worker
401*14675a02SAndroid Build Coastguard Worker
def get_grouped_output_tensor_specs_for_aggregations(
    aggregation_comp: tff.framework.ComputationBuildingBlock,
) -> list[list[tf.TensorSpec]]:
  """Groups the output TensorSpecs of an aggregation by intrinsic call.

  The groups are intended for the `output_tensors` field of
  ServerAggregationConfig messages: one group per aggregation intrinsic call in
  DistributeAggregateForm.client_to_server_aggregation.

  Tensor names come from flattening the aggregation's result type under the
  'intermediate_update' prefix, using naming logic similar to
  `variable_names_from_type`. They must agree with the tensor names the
  post-aggregation computation expects.

  Args:
    aggregation_comp: The aggregation computation. Every local of its `result`
      block is checked to be a federated call to an aggregation intrinsic.

  Returns:
    A list whose i-th entry holds the output tensor specs of the i-th intrinsic
    call, in the order the calls appear in `aggregation_comp.result.locals`.

  Raises:
    TypeError: If the argument is of the wrong type. NOTE(review): the
      per-local validation below relies on TFF `check_*` helpers plus an
      `assert`, so confirm which exception types actually surface.
  """
  # One flat list covering every leaf of the aggregation result. Its names
  # mirror variable_names_from_type applied to the output type of
  # client_to_server_aggregation, which is also the type of the aggregation
  # result argument of server_result.
  flat_specs = get_flattened_tensor_specs(
      tff.StructType([aggregation_comp.type_signature.result]),
      name='intermediate_update',
  )

  grouped_specs = []
  # Position in `flat_specs` where the next intrinsic's outputs begin. The walk
  # assumes the result's leaves are laid out in the same order as the locals.
  cursor = 0
  for _, intrinsic_call in aggregation_comp.result.locals:  # pytype: disable=attribute-error
    intrinsic_call.check_call()
    intrinsic_call.function.check_intrinsic()
    intrinsic_call.type_signature.check_federated()
    assert intrinsic_call.function.intrinsic_def().aggregation_kind

    member_type = intrinsic_call.type_signature.member
    if member_type.is_struct():
      # A struct-valued output owns one spec per leaf of its member type.
      width = len(tff.structure.flatten(member_type))
      grouped_specs.append(flat_specs[cursor : cursor + width])
    else:
      # A non-struct output owns exactly one spec.
      width = 1
      grouped_specs.append([flat_specs[cursor]])
    cursor += width

  return grouped_specs
461