xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/checkpoint_utils.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Helper methods for working with demo server checkpoints."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerimport collections
17*14675a02SAndroid Build Coastguard Workerfrom collections.abc import Callable, Iterable, Mapping
18*14675a02SAndroid Build Coastguard Workerfrom typing import Any, Optional, Union
19*14675a02SAndroid Build Coastguard Worker
20*14675a02SAndroid Build Coastguard Workerimport numpy as np
21*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
22*14675a02SAndroid Build Coastguard Workerimport tensorflow_federated as tff
23*14675a02SAndroid Build Coastguard Worker
24*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import artifact_constants
25*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import tensor_utils
26*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import type_checks
27*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import variable_helpers
28*14675a02SAndroid Build Coastguard Workerfrom fcp.protos import plan_pb2
29*14675a02SAndroid Build Coastguard Worker
# Name given to the temporary saver built by
# `create_server_checkpoint_vars_and_savepoint` when metrics are written to the
# server checkpoint; its save op covers state, metadata and metric variables.
SAVE_SERVER_SAVEPOINT_NAME: str = 'save_server_savepoint'
31*14675a02SAndroid Build Coastguard Worker
32*14675a02SAndroid Build Coastguard Worker
def create_server_checkpoint_vars_and_savepoint(
    *,
    server_state_type: tff.StructType,
    server_metrics_type: Optional[tff.StructType] = None,
    write_metrics_to_checkpoint: bool = True,
    additional_checkpoint_metadata_var_fn: Optional[
        Callable[
            [list[tf.Variable], Optional[list[tf.Variable]], bool],
            list[tf.Variable],
        ]
    ] = None,
) -> tuple[
    list[tf.Variable],
    list[tf.Variable],
    list[tf.Variable],
    plan_pb2.CheckpointOp,
]:
  """Creates tf.Variables for a server checkpoint and the associated savepoint.

  The variables and the associated saver are constructed in the default graph.

  Only `server_state_type` is required. The state variables it describes are
  both saved and restored by the savepoint. If `server_metrics_type` is given,
  metric variables are created as well; they are never restored by the
  savepoint, and are only saved when `write_metrics_to_checkpoint` is True and
  the metrics type has at least one leaf.

  Args:
    server_state_type: A `tff.Type` with the type signature of the state. This
      is used to construct the server state variables stored in the
      checkpoint.
    server_metrics_type: Optional. A `tff.Type` with the type signature of the
      metrics. If provided, this is used to construct the metric variables
      that may be stored in the checkpoint.
    write_metrics_to_checkpoint: If False, revert to legacy behavior where
      metrics and other non-state values were handled by post-processing
      separate from the outputted checkpoint.
    additional_checkpoint_metadata_var_fn: An optional callable producing
      additional metadata variables. It is called once, after the state (and,
      if requested, metric) variables are created, with three positional
      arguments: the `list` of state variables, the `list` of metric variables
      (or `None` when `server_metrics_type` is None), and
      `write_metrics_to_checkpoint`. Its result is concatenated with the state
      variables, so it is expected to be a `list` of `tf.Variable`s.

  Returns:
    A tuple `(state_vars, metric_vars, metadata_vars, savepoint)`:
    - `state_vars` is a Python `list` of variables that hold the state.
    - `metric_vars` is a Python `list` of variables that hold the metrics
      (empty when `server_metrics_type` is None).
    - `metadata_vars` is a Python `list` of variables that hold optional
      metadata (empty when `additional_checkpoint_metadata_var_fn` is None).
    - `savepoint` is the associated savepoint, i.e., an instance of
      `plan_pb2.CheckpointOp` with a saver configured for saving the
      `state_vars`, `metadata_vars`, and, if metrics are written to the
      checkpoint, `metric_vars`, and restoring the `state_vars` and
      `metadata_vars`.
  """
  type_checks.check_type(server_state_type, tff.Type, name='server_state_type')
  state_vars = variable_helpers.create_vars_for_tff_type(
      server_state_type, artifact_constants.SERVER_STATE_VAR_PREFIX
  )

  metric_vars = []
  if server_metrics_type is not None:
    type_checks.check_type(
        server_metrics_type, tff.Type, name='server_metrics_type'
    )
    metric_vars = variable_helpers.create_vars_for_tff_type(
        server_metrics_type, artifact_constants.SERVER_METRICS_VAR_PREFIX
    )

  metadata_vars = []
  if additional_checkpoint_metadata_var_fn:
    # `None` (rather than an empty list) is passed when no metrics type was
    # supplied at all.
    metadata_vars = additional_checkpoint_metadata_var_fn(
        state_vars,
        metric_vars if server_metrics_type is not None else None,
        write_metrics_to_checkpoint,
    )

  save_tensor_name = None
  if server_metrics_type is not None:
    has_metrics = bool(tff.structure.flatten(server_metrics_type))
    if has_metrics and write_metrics_to_checkpoint:
      # This saver only exists to provide a save op that covers the state,
      # metadata and metric variables; its save tensor name is spliced into
      # `savepoint` below.
      temp_saver_for_all_vars = create_deterministic_saver(
          var_list=state_vars + metadata_vars + metric_vars,
          name=SAVE_SERVER_SAVEPOINT_NAME,
      )
      temp_saver_def = temp_saver_for_all_vars.as_saver_def()
      save_tensor_name = temp_saver_def.save_tensor_name

  # The savepoint's own saver (and hence its restore op) covers only the state
  # and metadata variables.
  saver = create_deterministic_saver(
      var_list=state_vars + metadata_vars,
      name='{}_savepoint'.format(artifact_constants.SERVER_STATE_VAR_PREFIX),
  )
  savepoint = plan_pb2.CheckpointOp()
  savepoint.saver_def.CopyFrom(saver.as_saver_def())

  if save_tensor_name is not None:
    # Replace the save_tensor_name to the one in
    # temp_saver_for_all_vars so that we are additionally saving metrics vars
    # in the checkpoint that don't need to be restored as part of the input
    # computation state.
    # Once we create the server GraphDef, we will edit the GraphDef directly
    # to ensure the input filename links to the filename tensor from the
    # `savepoint`.
    savepoint.saver_def.save_tensor_name = save_tensor_name
  return state_vars, metric_vars, metadata_vars, savepoint
140*14675a02SAndroid Build Coastguard Worker
141*14675a02SAndroid Build Coastguard Worker
def create_state_vars_and_savepoint(
    type_spec: variable_helpers.AllowedTffTypes, name: str
) -> tuple[list[tf.Variable], plan_pb2.CheckpointOp]:
  """Builds state variables together with a savepoint that saves/restores them.

  All graph objects are created in the default graph.

  Args:
    type_spec: A `tff.Type` giving the type signature of the state.
    name: Base string used to name both the variables (named by
      `variable_helpers.create_vars_for_tff_type`, presumably under
      `${name}_state`) and the saver, which is named `${name}_savepoint`.

  Returns:
    A `(vars, savepoint)` pair: `vars` is the Python `list` of state
    variables, and `savepoint` is a `plan_pb2.CheckpointOp` whose `saver_def`
    is the `SaverDef` of a deterministic saver over `vars`.

  Raises:
    ValueError: If `name` is empty.
  """
  variables, tf_saver = create_state_vars_and_saver(type_spec, name)
  # Passing a message for a submessage field in the constructor copies it.
  return variables, plan_pb2.CheckpointOp(saver_def=tf_saver.as_saver_def())
167*14675a02SAndroid Build Coastguard Worker
168*14675a02SAndroid Build Coastguard Worker
def create_state_vars_and_saver(
    type_spec: variable_helpers.AllowedTffTypes, name: str
) -> tuple[list[tf.Variable], tf.compat.v1.train.Saver]:
  """Creates state variables and the associated saver.

  The variables and the associated saver are constructed in the default graph.

  Args:
    type_spec: An instance of `tff.Type` with the type signature of the state.
    name: The string to use as a basis for naming the vars and the saver. It
      is passed to `variable_helpers.create_vars_for_tff_type` to name the
      vars (presumably under `${name}_state`), and the saver is named
      `${name}_savepoint`.

  Returns:
    A tuple `(vars, saver)`, where `vars` is a Python `list` of variables
    that hold the state, and `saver` is the associated deterministic
    `tf.compat.v1.train.Saver` over those variables.

  Raises:
    ValueError: If the name is empty.
  """
  # `type_checks.check_type` rejects a non-`tff.Type` spec or non-`str` name.
  type_checks.check_type(type_spec, tff.Type, name='type_spec')
  type_checks.check_type(name, str, name='name')
  if not name:
    raise ValueError('Name cannot be empty.')
  state_vars = variable_helpers.create_vars_for_tff_type(type_spec, name)
  saver = create_deterministic_saver(state_vars, name=f'{name}_savepoint')
  return state_vars, saver
198*14675a02SAndroid Build Coastguard Worker
199*14675a02SAndroid Build Coastguard Worker
def restore_tensors_from_savepoint(
    tensor_specs: Iterable[tf.TensorSpec], filepath_tensor: tf.Tensor
) -> list[tf.Tensor]:
  """Reads one tensor per spec from the checkpoint at `filepath_tensor`.

  Args:
    tensor_specs: The `tf.TensorSpec`s to restore; each spec's name (reduced
      via `tensor_utils.bare_name`) selects the checkpoint entry and its dtype
      gives the restored tensor's dtype.
    filepath_tensor: A placeholder tensor that contains file names with a given
      pattern.

  Returns:
    The restored tensors, in the same order as `tensor_specs`.
  """
  restored = []
  for tensor_spec in tensor_specs:
    checkpoint_name = tensor_utils.bare_name(tensor_spec.name)
    restored.append(
        tensor_utils.restore(filepath_tensor, checkpoint_name, tensor_spec.dtype)
    )
  return restored
220*14675a02SAndroid Build Coastguard Worker
221*14675a02SAndroid Build Coastguard Worker
def create_deterministic_saver(
    var_list: Union[Iterable[tf.Variable], Mapping[str, tf.Variable]],
    *args,
    **kwargs,
) -> tf.compat.v1.train.Saver:
  """Creates a `tf.compat.v1.train.Saver` that is deterministic.

  This method sorts the `var_list` to ensure a deterministic ordering which
  in turn ensures a deterministic checkpoint.

  Uses `tf.compat.v1.train.SaverDef.V1` version for writing checkpoints; for
  that reason `write_version` must not be supplied in `kwargs` (Python would
  raise a `TypeError` for the duplicate keyword).

  Args:
    var_list: An `Iterable` or `str` keyed `Mapping` of `tf.Variables`. In the
      case of a `dict`, the keys become the names of the checkpoint variables
      (rather than reading the names off the `tf.Variable` values). Mappings
      are ordered by key, iterables by each variable's `name`.
    *args: Positional arguments forwarded to the `tf.compat.v1.train.Saver`
      constructor, placed after `var_list`.
    **kwargs: Keyword arguments forwarded to the `tf.compat.v1.train.Saver`
      constructor.

  Returns:
    A `tf.compat.v1.train.Saver` instance.

  Raises:
    ValueError: If `var_list` is neither a `Mapping` nor an `Iterable`.
  """
  if isinstance(var_list, Mapping):
    # Mapping keys are unique, so ordering by key alone is total.
    deterministic_var_list = collections.OrderedDict(
        sorted(var_list.items(), key=lambda item: item[0])
    )
  elif isinstance(var_list, Iterable):
    deterministic_var_list = sorted(var_list, key=lambda v: v.name)
  else:
    raise ValueError(
        'Do not know how to make a deterministic saver for '
        '`var_list` of type [{t}]. Must be a Mapping or Iterable'.format(
            t=type(var_list)
        )
    )
  return tf.compat.v1.train.Saver(
      deterministic_var_list,
      *args,
      write_version=tf.compat.v1.train.SaverDef.V1,
      **kwargs,
  )
263*14675a02SAndroid Build Coastguard Worker
264*14675a02SAndroid Build Coastguard Worker
def tff_type_to_dtype_list(
    tff_type: variable_helpers.AllowedTffTypes,
) -> list[tf.DType]:
  """Flattens a `tff.Type` into the dtypes of its leaf tensors.

  Federated wrappers are looked through, and struct elements are visited in
  order, depth first.

  Args:
    tff_type: Either a `tff.StructType`, `tff.FederatedType`, or a
      `tff.TensorType` object.

  Returns:
    A flat list of `tf.DType`s, one per leaf tensor.
  """
  type_checks.check_type(
      tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
  )
  if isinstance(tff_type, tff.TensorType):
    return [tff_type.dtype]
  if isinstance(tff_type, tff.FederatedType):
    return tff_type_to_dtype_list(tff_type.member)
  # Remaining case: tff.StructType.
  return [
      leaf_dtype
      for member_type in tff_type
      for leaf_dtype in tff_type_to_dtype_list(member_type)
  ]
289*14675a02SAndroid Build Coastguard Worker
290*14675a02SAndroid Build Coastguard Worker
def tff_type_to_tensor_spec_list(
    tff_type: variable_helpers.AllowedTffTypes,
) -> list[tf.TensorSpec]:
  """Flattens a `tff.Type` into one `tf.TensorSpec` per leaf tensor.

  Federated wrappers are looked through, and struct elements are visited in
  order, depth first. The specs carry shape and dtype but no name.

  Args:
    tff_type: Either a `tff.StructType`, `tff.FederatedType` or a
      `tff.TensorType` object.

  Returns:
    A flat list of `tf.TensorSpec`s.
  """
  type_checks.check_type(
      tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
  )
  if isinstance(tff_type, tff.TensorType):
    return [tf.TensorSpec(tff_type.shape, dtype=tff_type.dtype)]
  if isinstance(tff_type, tff.FederatedType):
    return tff_type_to_tensor_spec_list(tff_type.member)
  # Remaining case: tff.StructType.
  specs = []
  for member_type in tff_type:
    specs += tff_type_to_tensor_spec_list(member_type)
  return specs
315*14675a02SAndroid Build Coastguard Worker
316*14675a02SAndroid Build Coastguard Worker
def pack_tff_value(
    tff_type: variable_helpers.AllowedTffTypes, value_list: Any
) -> Any:
  """Packs a list of values into a shape specified by a `tff.Type`.

  Args:
    tff_type: Either a `tff.StructType`, `tff.FederatedType`, or a
      `tff.TensorType` object.
    value_list: A flat list of `tf.Tensor` or `CheckpointTensorReference`.

  Returns:
    If `tff_type` (with `tff.FederatedType` wrappers removed) is a
    `tff.TensorType`, the single element of `value_list`; otherwise the result
    of `tff.structure.pack_sequence_as` applied to the struct type and
    `value_list`.

  Raises:
    ValueError: If the number of leaves in `tff_type` does not match the length
    of `value_list`, or `tff_type` is of a disallowed type.
  """
  type_checks.check_type(
      tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
  )

  # We must "unwrap" any FederatedTypes because the
  # `tff.structure.pack_sequence_as` call below will fail to recurse into them.
  # Instead, we remove all the FederatedTypes, because we're only trying to
  # build up a Python tree structure that matches the struct/tensor types from a
  # list of values.
  def remove_federated_types(
      type_spec: tff.Type,
  ) -> Union[tff.StructType, tff.TensorType]:
    """Removes `FederatedType` from a type tree, returning a new tree."""
    # `isinstance` checks match the other helpers in this module and do not
    # rely on the `Type.is_*()` convenience methods.
    if isinstance(type_spec, tff.TensorType):
      return type_spec
    elif isinstance(type_spec, tff.FederatedType):
      # NOTE(review): the member is returned as-is (not recursed into); this
      # assumes a federated member never itself contains federated types.
      return type_spec.member
    elif isinstance(type_spec, tff.StructType):
      return tff.StructType(
          (elem_name, remove_federated_types(elem_type))
          for elem_name, elem_type in tff.structure.iter_elements(type_spec)
      )
    else:
      raise ValueError(
          'Must be either tff.TensorType, tff.FederatedType, or tff.StructType.'
          f' Got a {type(type_spec)}'
      )

  try:
    tff_type = remove_federated_types(tff_type)
  except ValueError as e:
    raise ValueError(
        '`tff_type` is not packable, see earlier error. '
        f'Attempted to pack type: {tff_type}'
    ) from e

  ordered_dtypes = tff_type_to_dtype_list(tff_type)
  if len(ordered_dtypes) != len(value_list):
    raise ValueError(
        'The number of leaves in `tff_type` must equals the length'
        ' of `value_list`. Found `tff_type` with'
        f' {len(ordered_dtypes)} leaves and `value_list` of length'
        f' {len(value_list)}.'
    )

  if isinstance(tff_type, tff.TensorType):
    return value_list[0]
  elif isinstance(tff_type, tff.StructType):
    return tff.structure.pack_sequence_as(tff_type, value_list)
  else:
    # Unreachable: `remove_federated_types` only returns tensor/struct types.
    raise ValueError(
        '`tff_type` must be either tff.TensorType or '
        'tff.StructType, reaching here is an internal coding '
        'error, please file a bug.'
    )
389*14675a02SAndroid Build Coastguard Worker
390*14675a02SAndroid Build Coastguard Worker
def variable_names_from_structure(
    tff_structure: Union[tff.structure.Struct, tf.Tensor], name: str = 'v'
) -> list[str]:
  """Builds the flat list of variable names for a value structure.

  A single `tf.Tensor` is named `name` (default `v`). For a
  `tff.structure.Struct`, each element is scoped under its parent with a '/'
  separator. The element's own name is used when it has one; otherwise its
  position within the struct is used. Nested structs recurse, so the
  resulting list follows the depth-first order of `tff.structure.iter_elements`.

  Some examples:
  1. If the `tff_structure` is `<'a'=tf.constant(1.0), 'b'=tf.constant(0.0)>`
     and name is not specified, the returned variable name list is
     ['v/a', 'v/b'].
  2. If the `tff_structure` is `<None=tf.constant(1.0), None=tf.constant(0.0)>`
     and `name` is `update`, the returned variable name list is
     ['update/0', 'update/1'].
  3. If the `tff_structure` is
     `<'a'=<'b'=tf.constant(1.0), 'c'=tf.constant(0.0)>>` and `name` is
     `update`, the returned variable name list is ['update/a/b', 'update/a/c'].
  4. If the `tff_structure` is
     `<'a'=<'b'=tf.constant(1.0), 'c'=tf.constant(1.0), tf.constant(0.0)>>` and
     `name` is `update`, the returned variable name list is
     ['update/a/b', 'update/a/c', 'update/a/2'].

  Args:
    tff_structure: Either a `tff.structure.Struct` or a `tf.Tensor` object.
    name: The name used at the top-most level; must be a string. When
      `tff_structure` is a `tff.structure.Struct`, the names of its inner
      fields are scoped under `name`, e.g. `some_name/field_name`.

  Returns:
    A flat Python `list` of `str` names.

  Raises:
    TypeError: If either argument is of the wrong type.
  """
  type_checks.check_type(
      tff_structure, (tff.structure.Struct, tf.Tensor), name='structure_type'
  )
  type_checks.check_type(name, str, name='name')

  # A lone tensor is a leaf: it simply takes the requested name.
  if isinstance(tff_structure, tf.Tensor):
    return [name]

  # Defensive only: `check_type` above already rejects anything that is
  # neither a Struct nor a Tensor.
  if not isinstance(tff_structure, tff.structure.Struct):
    raise TypeError(
        'Cannot create variable names from [{t}] type. Short-hand: {s}'.format(
            t=type(tff_structure), s=tff_structure
        )
    )

  names = []
  elements = tff.structure.iter_elements(tff_structure)
  for position, (element_name, element) in enumerate(elements):
    # Unnamed elements fall back to their position so that sibling leaves do
    # not all collapse onto the parent's name.
    scope = f'{name}/{element_name or position}'
    names += variable_names_from_structure(element, name=scope)
  return names
453*14675a02SAndroid Build Coastguard Worker
454*14675a02SAndroid Build Coastguard Worker
def is_structure_of_allowed_types(
    structure: Union[
        tff.structure.Struct,
        tf.Tensor,
        np.ndarray,
        np.number,
        int,
        float,
        str,
        bytes,
    ]
) -> bool:
  """Reports whether every leaf of `structure` may be serialized.

  The structure is flattened with `tff.structure.flatten`. A leaf is accepted
  if `tf.is_tensor` holds for it, or if it is an instance of `np.ndarray`,
  `np.number`, `int`, `float`, `str` or `bytes`. Checking stops at the first
  rejected leaf.

  Args:
    structure: A `tff.structure.Struct` of values, or a single value.

  Returns:
    True if all leaves are of an accepted type, False otherwise.
  """
  accepted_leaf_types = (np.ndarray, np.number, int, float, str, bytes)
  leaves = tff.structure.flatten(structure)
  return all(
      tf.is_tensor(leaf) or isinstance(leaf, accepted_leaf_types)
      for leaf in leaves
  )
476*14675a02SAndroid Build Coastguard Worker
477*14675a02SAndroid Build Coastguard Worker
def save_tff_structure_to_checkpoint(
    tff_structure: Union[tff.structure.Struct, tf.Tensor],
    ordered_var_names: list[str],
    output_checkpoint_path: str,
) -> None:
  """Saves a TFF structure to a checkpoint file.

  The input `tff_structure` is either a `tff.structure.Struct` or a single
  `tf.Tensor`. This function flattens `tff_structure` with
  `tff.structure.flatten()` and saves the resulting leaves to a checkpoint file
  via `tensor_utils.save`, using the variable names supplied in
  `ordered_var_names`.

  Args:
    tff_structure: A `tff.structure.Struct` of values or a single value. Each
      leaf in the structure must be a value serializable to a TensorFlow
      checkpoint (see `is_structure_of_allowed_types`).
    ordered_var_names: The list of variable names for the values that appear in
      `tff_structure` after calling `tff.structure.flatten()`, in the same
      order.
    output_checkpoint_path: A string specifying the path to the output
      checkpoint file.

  Raises:
    TypeError: If not all leaves in `tff_structure` are of allowed types.
    ValueError: If the number of leaves in `tff_structure` does not match
      the length of `ordered_var_names`.
  """
  if not is_structure_of_allowed_types(tff_structure):
    # Report the type of every leaf so the offending one is easy to spot.
    raise TypeError(
        'Not all leaves in `tff_structure` are `tf.Tensor`s, '
        '`np.ndarray`s, `np.number`s, or Python scalars. Got: '
        f'{tff.structure.map_structure(type, tff_structure)!r}'
    )

  tensors = tff.structure.flatten(tff_structure)
  # Names and leaves are handed to `tensor_utils.save` as parallel lists, so
  # their lengths must agree before writing anything.
  if len(tensors) != len(ordered_var_names):
    raise ValueError(
        'The length of `ordered_var_names` does not match the '
        'number of tensors in `tff_structure`: '
        f'{len(ordered_var_names)} != {len(tensors)}'
    )

  tensor_utils.save(
      output_checkpoint_path, tensor_names=ordered_var_names, tensors=tensors
  )
521