# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper methods for working with demo server checkpoints."""

import collections
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Optional, Union

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import artifact_constants
from fcp.artifact_building import tensor_utils
from fcp.artifact_building import type_checks
from fcp.artifact_building import variable_helpers
from fcp.protos import plan_pb2

SAVE_SERVER_SAVEPOINT_NAME = 'save_server_savepoint'


def create_server_checkpoint_vars_and_savepoint(
    *,
    server_state_type: tff.StructType,
    server_metrics_type: Optional[tff.StructType] = None,
    write_metrics_to_checkpoint: bool = True,
    additional_checkpoint_metadata_var_fn: Optional[
        Callable[[tff.StructType, tff.StructType, bool], list[tf.Variable]]
    ] = None,
) -> tuple[
    list[tf.Variable],
    list[tf.Variable],
    list[tf.Variable],
    plan_pb2.CheckpointOp,
]:
  """Creates tf.Variables for a server checkpoint and the associated savepoint.

  The variables and the associated saver are constructed in the default graph.

  Only `server_state_type` is required. `server_state_type` refers to the
  server state portion of the checkpoint and is used in the `Restore` op of
  the savepoint. The `server_metrics_type` refers to the metrics saved in the
  checkpoint, and is not used in the `Restore` op of the savepoint.

  Args:
    server_state_type: A `tff.Type` with the type signature of the state. This
      is used to construct the server state variable names stored in the
      checkpoint and to create the metadata variables for the checkpoint.
    server_metrics_type: Optional. A `tff.Type` with the type signature of the
      metrics. If provided, this is used to construct the metric variable names
      that are stored in the checkpoint.
    write_metrics_to_checkpoint: If False, revert to legacy behavior where
      metrics and other non-state values were handled by post-processing
      separate from the outputted checkpoint.
    additional_checkpoint_metadata_var_fn: An optional callable invoked with
      the state variables, the metric variables (or `None` when no metrics
      type was given), and `write_metrics_to_checkpoint`, returning additional
      metadata variables to include in the checkpoint.

  Returns:
    A tuple `(state_vars, metric_vars, metadata_vars, savepoint)`:
    - `state_vars` is a Python `list` of variables that hold the state.
    - `metric_vars` is a Python `list` of variables that hold the metrics.
    - `metadata_vars` is a Python `list` of variables that hold optional
      metadata.
    - `savepoint` is the associated savepoint, i.e., an instance of
      `plan_pb2.CheckpointOp` with a saver configured for saving the
      `state_vars`, `metadata_vars`, and, if write_metrics_to_checkpoint is
      True, `metric_vars`, and restoring the `state_vars` and
      `metadata_vars`.
  """
  has_metrics = False
  metric_vars = []
  save_tensor_name = None
  type_checks.check_type(server_state_type, tff.Type, name='server_state_type')
  state_vars = variable_helpers.create_vars_for_tff_type(
      server_state_type, artifact_constants.SERVER_STATE_VAR_PREFIX
  )
  var_names = list(map(tensor_utils.bare_name, state_vars))
  metadata_vars = []
  if server_metrics_type is not None:
    type_checks.check_type(
        server_metrics_type, tff.Type, name='server_metrics_type'
    )
    metric_vars = variable_helpers.create_vars_for_tff_type(
        server_metrics_type, artifact_constants.SERVER_METRICS_VAR_PREFIX
    )
    if additional_checkpoint_metadata_var_fn:
      metadata_vars = additional_checkpoint_metadata_var_fn(
          state_vars, metric_vars, write_metrics_to_checkpoint
      )

    has_metrics = bool(tff.structure.flatten(server_metrics_type))
    if has_metrics and write_metrics_to_checkpoint:
      var_names.extend(list(map(tensor_utils.bare_name, metric_vars)))

      # Build a temporary saver over *all* vars (including metrics) purely to
      # obtain a save op that writes metrics; the restore path below still
      # only restores state and metadata.
      temp_saver_for_all_vars = create_deterministic_saver(
          var_list=state_vars + metadata_vars + metric_vars,
          name=SAVE_SERVER_SAVEPOINT_NAME,
      )
      temp_saver_def = temp_saver_for_all_vars.as_saver_def()
      save_tensor_name = temp_saver_def.save_tensor_name
  else:
    if additional_checkpoint_metadata_var_fn:
      metadata_vars = additional_checkpoint_metadata_var_fn(
          state_vars, None, write_metrics_to_checkpoint
      )

  saver = create_deterministic_saver(
      var_list=state_vars + metadata_vars,
      name='{}_savepoint'.format(artifact_constants.SERVER_STATE_VAR_PREFIX),
  )
  savepoint = plan_pb2.CheckpointOp()
  savepoint.saver_def.CopyFrom(saver.as_saver_def())

  if save_tensor_name is not None:
    # Replace the save_tensor_name to the one in
    # temp_saver_for_all_vars so that we are additionally saving metrics vars
    # in the checkpoint that don't need to be restored as part of the input
    # computation state.
    # Once we create the server GraphDef, we will edit the GraphDef directly
    # to ensure the input filename links to the filename tensor from the
    # `savepoint`.
    savepoint.saver_def.save_tensor_name = save_tensor_name
  return state_vars, metric_vars, metadata_vars, savepoint


def create_state_vars_and_savepoint(
    type_spec: variable_helpers.AllowedTffTypes, name: str
) -> tuple[list[tf.Variable], plan_pb2.CheckpointOp]:
  """Creates state variables and their savepoint as a `plan_pb2.CheckpointOp`.

  The variables and the associated saver are constructed in the default graph.

  Args:
    type_spec: An instance of `tff.Type` with the type signature of the state.
    name: The string to use as a basis for naming the vars and the saver. The
      vars will be under `${name}_state`, and saver under `${name}_savepoint`.

  Returns:
    A tuple `(vars, savepoint)`, where `vars` is a Python `list` of variables
    that hold the state, and `savepoint` is the associated savepoint, i.e.,
    an instance of `plan_pb2.CheckpointOp` with a saver configured for saving
    and restoring the `vars`.

  Raises:
    ValueError: If the name is empty.
  """
  state_vars, saver = create_state_vars_and_saver(type_spec, name)
  savepoint = plan_pb2.CheckpointOp()
  savepoint.saver_def.CopyFrom(saver.as_saver_def())
  return state_vars, savepoint


def create_state_vars_and_saver(
    type_spec: variable_helpers.AllowedTffTypes, name: str
) -> tuple[list[tf.Variable], tf.compat.v1.train.Saver]:
  """Creates state variables and the associated saver.

  The variables and the associated saver are constructed in the default graph.

  Args:
    type_spec: An instance of `tff.Type` with the type signature of the state.
    name: The string to use as a basis for naming the vars and the saver. The
      vars will be under `${name}_state`, and saver under `${name}_savepoint`.

  Returns:
    A tuple `(vars, saver)`, where `vars` is a Python `list` of variables
    that hold the state, and `saver` is the associated
    `tf.compat.v1.train.Saver`.

  Raises:
    ValueError: If the name is empty.
  """
  type_checks.check_type(type_spec, tff.Type, name='type_spec')
  type_checks.check_type(name, str, name='name')
  if not name:
    raise ValueError('Name cannot be empty.')
  state_vars = variable_helpers.create_vars_for_tff_type(type_spec, name)
  saver = create_deterministic_saver(
      state_vars, name='{}_savepoint'.format(name)
  )
  return state_vars, saver


def restore_tensors_from_savepoint(
    tensor_specs: Iterable[tf.TensorSpec], filepath_tensor: tf.Tensor
) -> list[tf.Tensor]:
  """Restores tensors from a checkpoint designated by a tensor filepath.

  Args:
    tensor_specs: A `list` of `tf.TensorSpec`s with the names and dtypes of the
      tensors to restore.
    filepath_tensor: A placeholder tensor that contains file names with a given
      pattern.

  Returns:
    A list of restored tensors.
  """
  return [
      tensor_utils.restore(
          filepath_tensor, tensor_utils.bare_name(spec.name), spec.dtype
      )
      for spec in tensor_specs
  ]


def create_deterministic_saver(
    var_list: Union[Iterable[tf.Variable], Mapping[str, tf.Variable]],
    *args,
    **kwargs,
) -> tf.compat.v1.train.Saver:
  """Creates a `tf.compat.v1.train.Saver` that is deterministic.

  This method sorts the `var_list` to ensure a deterministic ordering which
  in turn ensures a deterministic checkpoint.

  Uses `tf.compat.v1.train.SaverDef.V1` version for writing checkpoints.

  Args:
    var_list: An `Iterable` or `str` keyed `Mapping` of `tf.Variables`. In the
      case of a `dict`, the keys become the names of the checkpoint variables
      (rather than reading the names off the `tf.Variable` values).
    *args: Positional arguments forwarded to the `tf.compat.v1.train.Saver`
      constructor.
    **kwargs: Keyword arguments forwarded to the `tf.compat.v1.train.Saver`
      constructor.

  Returns:
    A `tf.compat.v1.train.Saver` instance.
  """
  # Sort by key (for mappings) or variable name (for iterables) so the saver's
  # variable order — and thus the checkpoint bytes — are reproducible.
  if isinstance(var_list, collections.abc.Mapping):
    deterministic_names = collections.OrderedDict(sorted(var_list.items()))
  elif isinstance(var_list, collections.abc.Iterable):
    deterministic_names = sorted(var_list, key=lambda v: v.name)
  else:
    raise ValueError(
        'Do not know how to make a deterministic saver for '
        '`var_list` of type [{t}]. Must be a Mapping or Sequence'.format(
            t=type(var_list)
        )
    )
  return tf.compat.v1.train.Saver(
      deterministic_names,
      *args,
      write_version=tf.compat.v1.train.SaverDef.V1,
      **kwargs,
  )


def tff_type_to_dtype_list(
    tff_type: variable_helpers.AllowedTffTypes,
) -> list[tf.DType]:
  """Creates a flat list of `tf.DType`s for tensors in a `tff.Type`.

  Args:
    tff_type: Either a `tff.StructType`, `tff.FederatedType`, or a
      `tff.TensorType` object.

  Returns:
    A flat list of `tf.DType`s.
  """
  type_checks.check_type(
      tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
  )
  if isinstance(tff_type, tff.TensorType):
    return [tff_type.dtype]
  elif isinstance(tff_type, tff.FederatedType):
    return tff_type_to_dtype_list(tff_type.member)
  else:  # tff.StructType
    elem_list = []
    for elem_type in tff_type:
      elem_list.extend(tff_type_to_dtype_list(elem_type))
    return elem_list


def tff_type_to_tensor_spec_list(
    tff_type: variable_helpers.AllowedTffTypes,
) -> list[tf.TensorSpec]:
  """Creates a flat list of tensor specs for tensors in a `tff.Type`.

  Args:
    tff_type: Either a `tff.StructType`, `tff.FederatedType` or a
      `tff.TensorType` object.

  Returns:
    A flat list of `tf.TensorSpec`s.
  """
  type_checks.check_type(
      tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
  )
  if isinstance(tff_type, tff.TensorType):
    return [tf.TensorSpec(tff_type.shape, dtype=tff_type.dtype)]
  elif isinstance(tff_type, tff.FederatedType):
    return tff_type_to_tensor_spec_list(tff_type.member)
  else:  # tff.StructType
    elem_list = []
    for elem_type in tff_type:
      elem_list.extend(tff_type_to_tensor_spec_list(elem_type))
    return elem_list


def pack_tff_value(
    tff_type: variable_helpers.AllowedTffTypes, value_list: Any
) -> Any:
  """Packs a list of values into a shape specified by a `tff.Type`.

  Args:
    tff_type: Either a `tff.StructType`, `tff.FederatedType`, or a
      `tff.TensorType` object.
    value_list: A flat list of `tf.Tensor` or `CheckpointTensorReference`.

  Returns:
    A Python container with a structure consistent with a `tff.Type`.

  Raises:
    ValueError: If the number of leaves in `tff_type` does not match the length
      of `value_list`, or `tff_type` is of a disallowed type.
  """
  type_checks.check_type(
      tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
  )

  # We must "unwrap" any FederatedTypes because the
  # `tff.structure.pack_sequence_as` call below will fail to recurse into them.
  # Instead, we remove all the FederatedTypes, because we're only trying to
  # build up a Python tree structure that matches the struct/tensor types from a
  # list of values.
  def remove_federated_types(
      type_spec: tff.Type,
  ) -> Union[tff.StructType, tff.TensorType]:
    """Removes `FederatedType` from a type tree, returning a new tree."""
    if type_spec.is_tensor():
      return type_spec
    elif type_spec.is_federated():
      return type_spec.member
    elif type_spec.is_struct():
      return tff.StructType(
          (elem_name, remove_federated_types(elem_type))
          for elem_name, elem_type in tff.structure.iter_elements(type_spec)
      )
    else:
      raise ValueError(
          'Must be either tff.TensorType, tff.FederatedType, or tff.StructType.'
          f' Got a {type(type_spec)}'
      )

  try:
    tff_type = remove_federated_types(tff_type)
  except ValueError as e:
    raise ValueError(
        '`tff_type` is not packable, see earlier error. '
        f'Attempted to pack type: {tff_type}'
    ) from e

  ordered_dtypes = tff_type_to_dtype_list(tff_type)
  if len(ordered_dtypes) != len(value_list):
    raise ValueError(
        'The number of leaves in `tff_type` must equal the length'
        ' of `value_list`. Found `tff_type` with'
        f' {len(ordered_dtypes)} leaves and `value_list` of length'
        f' {len(value_list)}.'
    )

  if tff_type.is_tensor():
    return value_list[0]
  elif tff_type.is_struct():
    return tff.structure.pack_sequence_as(tff_type, value_list)
  else:
    raise ValueError(
        '`tff_type` must be either tff.TensorType or '
        'tff.StructType, reaching here is an internal coding '
        'error, please file a bug.'
    )


def variable_names_from_structure(
    tff_structure: Union[tff.structure.Struct, tf.Tensor], name: str = 'v'
) -> list[str]:
  """Creates a flattened list of variable names for the given structure.

  If the `tff_structure` is a `tf.Tensor`, the name is the `name` parameter if
  specified, otherwise a default name: `v`. If `tff_structure` is a
  `tff.structure.Struct` then '/' is used between inner and outer fields
  together with the tuple name or index of the element in the tuple.

  Some examples:
  1. If the `tff_structure` is `<'a'=tf.constant(1.0), 'b'=tf.constant(0.0)>`
     and name is not specified, the returned variable name list is
     ['v/a', 'v/b'].
  2. If the `tff_structure` is `<None=tf.constant(1.0), None=tf.constant(0.0)>`
     and `name` is `update`, the returned variable name list is
     ['update/0', 'update/1'].
  3. If the `tff_structure` is
     `<'a'=<'b'=tf.constant(1.0), 'c'=tf.constant(0.0)>>` and `name` is
     `update`, the returned variable name list is ['update/a/b', 'update/a/c'].
  4. If the `tff_structure` is
     `<'a'=<'b'=tf.constant(1.0), 'c'=tf.constant(1.0), tf.constant(0.0)>>` and
     `name` is `update`, the returned variable name list is ['update/a/b',
     'update/a/c', 'update/a/2'].

  Args:
    tff_structure: Either a `tff.structure.Struct` or a `tf.Tensor` object.
    name: The preferred name to use at the top-most level (if not None, must be
      a string). If `tff_structure` is a `tff.structure.Struct`, the names of
      the inner fields will be scoped under `name`, e.g. `some_name/field_name`.

  Returns:
    A flat Python `list` of `str` names.

  Raises:
    TypeError: If either argument is of the wrong type.
  """
  type_checks.check_type(
      tff_structure, (tff.structure.Struct, tf.Tensor), name='structure_type'
  )
  type_checks.check_type(name, str, name='name')
  if isinstance(tff_structure, tf.Tensor):
    return [name]
  elif isinstance(tff_structure, tff.structure.Struct):
    result = []
    fields = tff.structure.iter_elements(tff_structure)
    for index, (field_name, field_value) in enumerate(fields):
      # Default the name of the element to its index so that we don't wind up
      # with multiple child fields listed under `/v/`
      field_name = field_name or str(index)
      result.extend(
          variable_names_from_structure(
              field_value, name=name + '/' + field_name
          )
      )
    return result
  else:
    raise TypeError(
        'Cannot create variable names from [{t}] type. Short-hand: {s}'.format(
            t=type(tff_structure), s=tff_structure
        )
    )


def is_structure_of_allowed_types(
    structure: Union[
        tff.structure.Struct,
        tf.Tensor,
        np.ndarray,
        np.number,
        int,
        float,
        str,
        bytes,
    ]
) -> bool:
  """Checks if each node in `structure` is an allowed type for serialization."""
  flattened_structure = tff.structure.flatten(structure)
  for item in flattened_structure:
    if not (
        tf.is_tensor(item)
        or isinstance(item, (np.ndarray, np.number, int, float, str, bytes))
    ):
      return False
  return True


def save_tff_structure_to_checkpoint(
    tff_structure: Union[tff.structure.Struct, tf.Tensor],
    ordered_var_names: list[str],
    output_checkpoint_path: str,
) -> None:
  """Saves a TFF structure to a checkpoint file.

  The input `tff_structure` is a either `tff.structure.Struct` or a single
  `tf.Tensor`. This function saves `tff_structure` to a checkpoint file using
  variable names supplied via the `ordered_var_names` argument.

  Args:
    tff_structure: A `tff.structure.Struct` of values or a single value. Each
      leaf in the structure must be a value serializable to a TensorFlow
      checkpoint.
    ordered_var_names: The list of variable names for the values that appear in
      `tff_structure` after calling `tff.structure.flatten()`.
    output_checkpoint_path: A string specifying the path to the output
      checkpoint file.

  Raises:
    TypeError: If not all leaves in `tff_structure` are of allowed types.
    ValueError: If the number of `tf.Tensor`s in `tff_structure` does not match
      the size of `ordered_var_names`.
  """
  if not is_structure_of_allowed_types(tff_structure):
    raise TypeError(
        'Not all leaves in `tff_structure` are `tf.Tensor`s, '
        '`np.ndarray`s, `np.number`s, or Python scalars. Got: '
        f'{tff.structure.map_structure(type, tff_structure)!r})'
    )

  tensors = tff.structure.flatten(tff_structure)
  if len(tensors) != len(ordered_var_names):
    raise ValueError(
        'The length of `ordered_var_names` does not match the '
        'number of tensors in `tff_structure`:'
        f'{len(ordered_var_names)} != {len(tensors)}'
    )

  tensor_utils.save(
      output_checkpoint_path, tensor_names=ordered_var_names, tensors=tensors
  )