# xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/checkpoint_utils.py (revision 14675a029014e728ec732f129a32e299b2da0601)
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper methods for working with demo server checkpoints."""

import collections
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Optional, Union

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import artifact_constants
from fcp.artifact_building import tensor_utils
from fcp.artifact_building import type_checks
from fcp.artifact_building import variable_helpers
from fcp.protos import plan_pb2

# Name for the temporary saver that writes the full server savepoint (state,
# metadata, and metrics) in `create_server_checkpoint_vars_and_savepoint`.
SAVE_SERVER_SAVEPOINT_NAME = 'save_server_savepoint'
31
32
def create_server_checkpoint_vars_and_savepoint(
    *,
    server_state_type: tff.StructType,
    server_metrics_type: Optional[tff.StructType] = None,
    write_metrics_to_checkpoint: bool = True,
    additional_checkpoint_metadata_var_fn: Optional[
        Callable[
            [list[tf.Variable], Optional[list[tf.Variable]], bool],
            list[tf.Variable],
        ]
    ] = None,
) -> tuple[
    list[tf.Variable],
    list[tf.Variable],
    list[tf.Variable],
    plan_pb2.CheckpointOp,
]:
  """Creates tf.Variables for a server checkpoint and the associated savepoint.

  The variables and the associated saver are constructed in the default graph.

  Only `server_state_type` is required. `server_state_type` refers to the
  server state portion of the checkpoint and is used in the `Restore` op of
  the savepoint. If metrics are to be saved in the server checkpoint,
  `server_metrics_type` must also be provided; metric variables are written to
  the checkpoint (when `write_metrics_to_checkpoint` is True) but are never
  restored by the savepoint's `Restore` op.

  Args:
    server_state_type: A `tff.Type` with the type signature of the state. This
      is used to construct the server state variable names stored in the
      checkpoint.
    server_metrics_type: Optional. A `tff.Type` with the type signature of the
      metrics. If provided, this is used to construct the metric variable names
      that are stored in the checkpoint.
    write_metrics_to_checkpoint: If False, revert to legacy behavior where
      metrics and other non-state values were handled by post-processing
      separate from the outputted checkpoint.
    additional_checkpoint_metadata_var_fn: An optional callable that receives
      the state variables, the metric variables (or `None` when no
      `server_metrics_type` was given), and `write_metrics_to_checkpoint`, and
      returns additional metadata variables to include in the checkpoint.

  Returns:
    A tuple `(state_vars, metric_vars, metadata_vars, savepoint)`:
    - `state_vars` is a Python `list` of variables that hold the state.
    - `metric_vars` is a Python `list` of variables that hold the metrics.
    - `metadata_vars` is a Python `list` of variables that hold optional
      metadata.
    - `savepoint` is the associated savepoint, i.e., an instance of
      `plan_pb2.CheckpointOp` with a saver configured for saving the
      `state_vars`, `metadata_vars`, and, if write_metrics_to_checkpoint is
      True, `metric_vars`, and restoring the `state_vars` and
      `metadata_vars`.
  """
  has_metrics = False
  metric_vars = []
  # Save-op tensor name from the metrics-inclusive saver; stays None unless
  # metrics exist and should be written to the checkpoint.
  save_tensor_name = None
  type_checks.check_type(server_state_type, tff.Type, name='server_state_type')
  state_vars = variable_helpers.create_vars_for_tff_type(
      server_state_type, artifact_constants.SERVER_STATE_VAR_PREFIX
  )
  var_names = list(map(tensor_utils.bare_name, state_vars))
  metadata_vars = []
  if server_metrics_type is not None:
    type_checks.check_type(
        server_metrics_type, tff.Type, name='server_metrics_type'
    )
    metric_vars = variable_helpers.create_vars_for_tff_type(
        server_metrics_type, artifact_constants.SERVER_METRICS_VAR_PREFIX
    )
    if additional_checkpoint_metadata_var_fn:
      metadata_vars = additional_checkpoint_metadata_var_fn(
          state_vars, metric_vars, write_metrics_to_checkpoint
      )

    has_metrics = bool(tff.structure.flatten(server_metrics_type))
    if has_metrics and write_metrics_to_checkpoint:
      var_names.extend(list(map(tensor_utils.bare_name, metric_vars)))

      # Build a saver over state + metadata + metrics only to obtain its
      # save-op tensor name; the savepoint's restore still excludes metrics.
      temp_saver_for_all_vars = create_deterministic_saver(
          var_list=state_vars + metadata_vars + metric_vars,
          name=SAVE_SERVER_SAVEPOINT_NAME,
      )
      temp_saver_def = temp_saver_for_all_vars.as_saver_def()
      save_tensor_name = temp_saver_def.save_tensor_name
  else:
    if additional_checkpoint_metadata_var_fn:
      metadata_vars = additional_checkpoint_metadata_var_fn(
          state_vars, None, write_metrics_to_checkpoint
      )

  saver = create_deterministic_saver(
      var_list=state_vars + metadata_vars,
      name='{}_savepoint'.format(artifact_constants.SERVER_STATE_VAR_PREFIX),
  )
  savepoint = plan_pb2.CheckpointOp()
  savepoint.saver_def.CopyFrom(saver.as_saver_def())

  if save_tensor_name is not None:
    # Replace the save_tensor_name to the one in
    # temp_saver_for_all_vars so that we are additionally saving metrics vars
    # in the checkpoint that don't need to be restored as part of the input
    # computation state.
    # Once we create the server GraphDef, we will edit the GraphDef directly
    # to ensure the input filename links to the filename tensor from the
    # `savepoint`.
    savepoint.saver_def.save_tensor_name = save_tensor_name
  return state_vars, metric_vars, metadata_vars, savepoint
140
141
def create_state_vars_and_savepoint(
    type_spec: variable_helpers.AllowedTffTypes, name: str
) -> tuple[list[tf.Variable], plan_pb2.CheckpointOp]:
  """Builds state variables and wraps their saver in a `plan_pb2.CheckpointOp`.

  The variables and the associated saver are constructed in the default graph.

  Args:
    type_spec: An instance of `tff.Type` with the type signature of the state.
    name: The string to use as a basis for naming the vars and the saver. The
      vars will be under `${name}_state`, and saver under `${name}_savepoint`.

  Returns:
    A tuple `(vars, savepoint)`, where `vars` is a Python `list` of variables
    that hold the state, and `savepoint` is the associated savepoint, i.e.,
    an instance of `plan_pb2.CheckpointOp` with a saver configured for saving
    and restoring the `vars`.

  Raises:
    ValueError: If the name is empty.
  """
  variables, saver = create_state_vars_and_saver(type_spec, name)
  # Embed the saver's SaverDef inside a CheckpointOp proto for the plan.
  checkpoint_op = plan_pb2.CheckpointOp()
  checkpoint_op.saver_def.CopyFrom(saver.as_saver_def())
  return variables, checkpoint_op
167
168
def create_state_vars_and_saver(
    type_spec: variable_helpers.AllowedTffTypes, name: str
) -> tuple[list[tf.Variable], tf.compat.v1.train.Saver]:
  """Builds state variables and a deterministic saver for them.

  The variables and the associated saver are constructed in the default graph.

  Args:
    type_spec: An instance of `tff.Type` with the type signature of the state.
    name: The string to use as a basis for naming the vars and the saver. The
      vars will be under `${name}_state`, and saver under `${name}_savepoint`.

  Returns:
    A tuple `(vars, saver)`, where `vars` is a Python `list` of variables
    that hold the state, and `saver` is the associated
    `tf.compat.v1.train.Saver`.

  Raises:
    ValueError: If the name is empty.
  """
  # Validate inputs before creating any graph state.
  type_checks.check_type(type_spec, tff.Type, name='type_spec')
  type_checks.check_type(name, str, name='name')
  if not name:
    raise ValueError('Name cannot be empty.')
  variables = variable_helpers.create_vars_for_tff_type(type_spec, name)
  saver = create_deterministic_saver(variables, name=f'{name}_savepoint')
  return variables, saver
198
199
def restore_tensors_from_savepoint(
    tensor_specs: Iterable[tf.TensorSpec], filepath_tensor: tf.Tensor
) -> list[tf.Tensor]:
  """Restores tensors from a checkpoint designated by a tensor filepath.

  Args:
    tensor_specs: A `list` of `tf.TensorSpec`s with the names and dtypes of the
      tensors to restore.
    filepath_tensor: A placeholder tensor that contains file names with a given
      pattern.

  Returns:
    A list of restored tensors.
  """

  def _restore_one(spec: tf.TensorSpec) -> tf.Tensor:
    # Strip any device/scope decoration off the spec name before lookup.
    return tensor_utils.restore(
        filepath_tensor, tensor_utils.bare_name(spec.name), spec.dtype
    )

  return [_restore_one(spec) for spec in tensor_specs]
220
221
def create_deterministic_saver(
    var_list: Union[Iterable[tf.Variable], Mapping[str, tf.Variable]],
    *args,
    **kwargs,
) -> tf.compat.v1.train.Saver:
  """Creates a `tf.compat.v1.train.Saver` that is deterministic.

  This method sorts the `var_list` to ensure a deterministic ordering which
  in turn ensures a deterministic checkpoint.

  Uses `tf.compat.v1.train.SaverDef.V1` version for writing checkpoints.

  Args:
    var_list: An `Iterable` or `str` keyed `Mapping` of `tf.Variables`. In the
      case of a `dict`, the keys become the names of the checkpoint variables
      (rather than reading the names off the `tf.Variable` values).
    *args: Positional arguments forwarded to the `tf.compat.v1.train.Saver`
      constructor.
    **kwargs: Keyword arguments forwarded to the `tf.compat.v1.train.Saver`
      constructor.

  Returns:
    A `tf.compat.v1.train.Saver` instance.

  Raises:
    ValueError: If `var_list` is neither a `Mapping` nor an `Iterable`.
  """
  if isinstance(var_list, collections.abc.Mapping):
    # Sort by checkpoint variable name (the mapping key) for stable ordering.
    deterministic_names = collections.OrderedDict(sorted(var_list.items()))
  elif isinstance(var_list, collections.abc.Iterable):
    # Sort by the tf.Variable's own name for stable ordering.
    deterministic_names = sorted(var_list, key=lambda v: v.name)
  else:
    # NOTE: the check above accepts any Iterable, so the message says
    # "Iterable" (the previous message incorrectly said "Sequence").
    raise ValueError(
        'Do not know how to make a deterministic saver for '
        '`var_list` of type [{t}]. Must be a Mapping or Iterable'.format(
            t=type(var_list)
        )
    )
  return tf.compat.v1.train.Saver(
      deterministic_names,
      *args,
      write_version=tf.compat.v1.train.SaverDef.V1,
      **kwargs,
  )
263
264
def tff_type_to_dtype_list(
    tff_type: variable_helpers.AllowedTffTypes,
) -> list[tf.DType]:
  """Creates a flat list of `tf.DType`s for tensors in a `tff.Type`.

  Args:
    tff_type: Either a `tff.StructType`, `tff.FederatedType`, or a
      `tff.TensorType` object.

  Returns:
    A flat list of `tf.DType`s.
  """
  type_checks.check_type(
      tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
  )
  if isinstance(tff_type, tff.TensorType):
    return [tff_type.dtype]
  if isinstance(tff_type, tff.FederatedType):
    # Peel off the placement and recurse into the member type.
    return tff_type_to_dtype_list(tff_type.member)
  # tff.StructType: concatenate the dtypes of every element, in order.
  return [
      dtype
      for element_type in tff_type
      for dtype in tff_type_to_dtype_list(element_type)
  ]
289
290
def tff_type_to_tensor_spec_list(
    tff_type: variable_helpers.AllowedTffTypes,
) -> list[tf.TensorSpec]:
  """Creates a flat list of tensor specs for tensors in a `tff.Type`.

  Args:
    tff_type: Either a `tff.StructType`, `tff.FederatedType` or a
      `tff.TensorType` object.

  Returns:
    A flat list of `tf.TensorSpec`s.
  """
  type_checks.check_type(
      tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
  )
  if isinstance(tff_type, tff.TensorType):
    return [tf.TensorSpec(tff_type.shape, dtype=tff_type.dtype)]
  if isinstance(tff_type, tff.FederatedType):
    # Peel off the placement and recurse into the member type.
    return tff_type_to_tensor_spec_list(tff_type.member)
  # tff.StructType: concatenate the specs of every element, in order.
  return [
      spec
      for element_type in tff_type
      for spec in tff_type_to_tensor_spec_list(element_type)
  ]
315
316
def pack_tff_value(
    tff_type: variable_helpers.AllowedTffTypes, value_list: Any
) -> Any:
  """Packs a list of values into a shape specified by a `tff.Type`.

  Args:
    tff_type: Either a `tff.StructType`, `tff.FederatedType`, or a
      `tff.TensorType` object.
    value_list: A flat list of `tf.Tensor` or `CheckpointTensorReference`.

  Returns:
    A Python container with a structure consistent with a `tff.Type`.

  Raises:
    ValueError: If the number of leaves in `tff_type` does not match the length
      of `value_list`, or `tff_type` is of a disallowed type.
  """
  type_checks.check_type(
      tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
  )

  # We must "unwrap" any FederatedTypes because the
  # `tff.structure.pack_sequence_as` call below will fail to recurse into them.
  # Instead, we remove all the FederatedTypes, because we're only trying to
  # build up a Python tree structure that matches the struct/tensor types from a
  # list of values.
  def remove_federated_types(
      type_spec: tff.Type,
  ) -> Union[tff.StructType, tff.TensorType]:
    """Removes `FederatedType` from a type tree, returning a new tree."""
    # Use isinstance dispatch, consistent with the rest of this module.
    if isinstance(type_spec, tff.TensorType):
      return type_spec
    elif isinstance(type_spec, tff.FederatedType):
      return type_spec.member
    elif isinstance(type_spec, tff.StructType):
      return tff.StructType(
          (elem_name, remove_federated_types(elem_type))
          for elem_name, elem_type in tff.structure.iter_elements(type_spec)
      )
    else:
      raise ValueError(
          'Must be either tff.TensorType, tff.FederatedType, or tff.StructType.'
          f' Got a {type(type_spec)}'
      )

  try:
    tff_type = remove_federated_types(tff_type)
  except ValueError as e:
    raise ValueError(
        '`tff_type` is not packable, see earlier error. '
        f'Attempted to pack type: {tff_type}'
    ) from e

  ordered_dtypes = tff_type_to_dtype_list(tff_type)
  if len(ordered_dtypes) != len(value_list):
    # Fixed grammar in this message ("must equals" -> "must equal").
    raise ValueError(
        'The number of leaves in `tff_type` must equal the length'
        ' of `value_list`. Found `tff_type` with'
        f' {len(ordered_dtypes)} leaves and `value_list` of length'
        f' {len(value_list)}.'
    )

  if isinstance(tff_type, tff.TensorType):
    # A single tensor type packs to the single (only) value.
    return value_list[0]
  elif isinstance(tff_type, tff.StructType):
    return tff.structure.pack_sequence_as(tff_type, value_list)
  else:
    raise ValueError(
        '`tff_type` must be either tff.TensorType or '
        'tff.StructType, reaching here is an internal coding '
        'error, please file a bug.'
    )
389
390
def variable_names_from_structure(
    tff_structure: Union[tff.structure.Struct, tf.Tensor], name: str = 'v'
) -> list[str]:
  """Creates a flattened list of variable names for the given structure.

  If the `tff_structure` is a `tf.Tensor`, the name is the `name` parameter if
  specified, otherwise a default name: `v`. If `tff_structure` is a
  `tff.structure.Struct` then '/' is used between inner and outer fields
  together with the tuple name or index of the element in the tuple.

  Some examples:
  1. If the `tff_structure` is `<'a'=tf.constant(1.0), 'b'=tf.constant(0.0)>`
     and name is not specified, the returned variable name list is
     ['v/a', 'v/b'].
  2. If the `tff_structure` is `<None=tf.constant(1.0), None=tf.constant(0.0)>`
     and `name` is `update`, the returned variable name list is
     ['update/0', 'update/1'].
  3. If the `tff_structure` is
     `<'a'=<'b'=tf.constant(1.0), 'c'=tf.constant(0.0)>>` and `name` is
     `update`, the returned variable name list is ['update/a/b', 'update/a/c'].
  4. If the `tff_structure` is
     `<'a'=<'b'=tf.constant(1.0), 'c'=tf.constant(1.0), tf.constant(0.0)>>` and
     `name` is `update`, the returned variable name list is ['update/a/b',
    'update/a/c', 'update/a/2'].

  Args:
    tff_structure: Either a `tff.structure.Struct` or a `tf.Tensor` object.
    name: The preferred name to use at the top-most level (if not None, must be
      a string). If `tff_structure` is a `tff.structure.Struct`, the names of
      the inner fields will be scoped under `name`, e.g. `some_name/field_name`.

  Returns:
    A flat Python `list` of `str` names.

  Raises:
    TypeError: If either argument is of the wrong type.
  """
  type_checks.check_type(
      tff_structure, (tff.structure.Struct, tf.Tensor), name='structure_type'
  )
  type_checks.check_type(name, str, name='name')
  if isinstance(tff_structure, tf.Tensor):
    # Leaf: the accumulated scoped name is the variable name.
    return [name]
  if isinstance(tff_structure, tff.structure.Struct):
    names = []
    elements = tff.structure.iter_elements(tff_structure)
    for index, (field_name, field_value) in enumerate(elements):
      # Unnamed fields fall back to their positional index so that multiple
      # children never collide under the same scope.
      scoped = '{}/{}'.format(name, field_name or str(index))
      names.extend(variable_names_from_structure(field_value, name=scoped))
    return names
  raise TypeError(
      'Cannot create variable names from [{t}] type. Short-hand: {s}'.format(
          t=type(tff_structure), s=tff_structure
      )
  )
453
454
def is_structure_of_allowed_types(
    structure: Union[
        tff.structure.Struct,
        tf.Tensor,
        np.ndarray,
        np.number,
        int,
        float,
        str,
        bytes,
    ]
) -> bool:
  """Checks if each node in `structure` is an allowed type for serialization."""
  # A leaf is acceptable when it is a tensor or one of the plain value types.
  allowed_value_types = (np.ndarray, np.number, int, float, str, bytes)
  return all(
      tf.is_tensor(leaf) or isinstance(leaf, allowed_value_types)
      for leaf in tff.structure.flatten(structure)
  )
476
477
def save_tff_structure_to_checkpoint(
    tff_structure: Union[tff.structure.Struct, tf.Tensor],
    ordered_var_names: list[str],
    output_checkpoint_path: str,
) -> None:
  """Saves a TFF structure to a checkpoint file.

  The input `tff_structure` is a either `tff.structure.Struct` or a single
  `tf.Tensor`. This function saves `tff_structure` to a checkpoint file using
  variable names supplied via the `ordered_var_names` argument.

  Args:
    tff_structure: A `tff.structure.Struct` of values or a single value. Each
      leaf in the structure must be a value serializable to a TensorFlow
      checkpoint.
    ordered_var_names: The list of variable names for the values that appear in
      `tff_structure` after calling `tff.structure.flatten()`.
    output_checkpoint_path: A string specifying the path to the output
      checkpoint file.

  Raises:
    TypeError: If not all leaves in `tff_structure` are of allowed types.
    ValueError: If the number of `tf.Tensor`s in `tff_structure` does not match
      the size of `ordered_var_names`.
  """
  # Reject structures with non-serializable leaves up front.
  if not is_structure_of_allowed_types(tff_structure):
    raise TypeError(
        'Not all leaves in `tff_structure` are `tf.Tensor`s, '
        '`np.ndarray`s, `np.number`s, or Python scalars. Got: '
        f'{tff.structure.map_structure(type, tff_structure)!r})'
    )

  flat_tensors = tff.structure.flatten(tff_structure)
  if len(flat_tensors) != len(ordered_var_names):
    raise ValueError(
        'The length of `ordered_var_names` does not match the '
        'number of tensors in `tff_structure`:'
        f'{len(ordered_var_names)} != {len(flat_tensors)}'
    )

  tensor_utils.save(
      output_checkpoint_path,
      tensor_names=ordered_var_names,
      tensors=flat_tensors,
  )
521