# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper methods for working with demo server checkpoints."""

import collections
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Optional, Union

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import artifact_constants
from fcp.artifact_building import tensor_utils
from fcp.artifact_building import type_checks
from fcp.artifact_building import variable_helpers
from fcp.protos import plan_pb2

SAVE_SERVER_SAVEPOINT_NAME = 'save_server_savepoint'


def create_server_checkpoint_vars_and_savepoint(
    *,
    server_state_type: tff.StructType,
    server_metrics_type: Optional[tff.StructType] = None,
    write_metrics_to_checkpoint: bool = True,
    additional_checkpoint_metadata_var_fn: Optional[
        Callable[
            [list[tf.Variable], Optional[list[tf.Variable]], bool],
            list[tf.Variable],
        ]
    ] = None,
) -> tuple[
    list[tf.Variable],
    list[tf.Variable],
    list[tf.Variable],
    plan_pb2.CheckpointOp,
]:
  """Creates tf.Variables for a server checkpoint and the associated savepoint.

  The variables and the associated saver are constructed in the default graph.

  Only `server_state_type` is required. If metrics are to be saved in the
  server checkpoint, `server_metrics_type` must also be provided.
  `server_state_type` describes the server state portion of the checkpoint and
  is used in the `Restore` op of the savepoint. `server_metrics_type` describes
  the metrics saved in the checkpoint; metrics are not restored by the
  `Restore` op of the savepoint.

  Args:
    server_state_type: A `tff.Type` with the type signature of the state. This
      is used to construct the server state variable names stored in the
      checkpoint. The resulting state variables are also passed to
      `additional_checkpoint_metadata_var_fn`, if provided.
    server_metrics_type: Optional. A `tff.Type` with the type signature of the
      metrics. If provided, this is used to construct the metric variable names
      that are stored in the checkpoint.
    write_metrics_to_checkpoint: If False, revert to legacy behavior where
      metrics and other non-state values were handled by post-processing
      separate from the output checkpoint.
    additional_checkpoint_metadata_var_fn: An optional callable that produces
      additional metadata variables. It is called with the created state
      variables, the created metric variables (or `None` if
      `server_metrics_type` is not provided), and
      `write_metrics_to_checkpoint`.

  Returns:
    A tuple `(state_vars, metric_vars, metadata_vars, savepoint)`:
    - `state_vars` is a Python `list` of variables that hold the state.
    - `metric_vars` is a Python `list` of variables that hold the metrics.
    - `metadata_vars` is a Python `list` of variables that hold optional
      metadata.
    - `savepoint` is the associated savepoint, i.e., an instance of
      `plan_pb2.CheckpointOp` with a saver configured for saving the
      `state_vars`, `metadata_vars`, and, if `write_metrics_to_checkpoint` is
      True, `metric_vars`, and for restoring the `state_vars` and
      `metadata_vars`.
  """
  has_metrics = False
  metric_vars = []
  save_tensor_name = None
  type_checks.check_type(server_state_type, tff.Type, name='server_state_type')
  state_vars = variable_helpers.create_vars_for_tff_type(
      server_state_type, artifact_constants.SERVER_STATE_VAR_PREFIX
  )
  var_names = list(map(tensor_utils.bare_name, state_vars))
  metadata_vars = []
  if server_metrics_type is not None:
    type_checks.check_type(
        server_metrics_type, tff.Type, name='server_metrics_type'
    )
    metric_vars = variable_helpers.create_vars_for_tff_type(
        server_metrics_type, artifact_constants.SERVER_METRICS_VAR_PREFIX
    )
    if additional_checkpoint_metadata_var_fn:
      metadata_vars = additional_checkpoint_metadata_var_fn(
          state_vars, metric_vars, write_metrics_to_checkpoint
      )

    has_metrics = bool(tff.structure.flatten(server_metrics_type))
    if has_metrics and write_metrics_to_checkpoint:
      var_names.extend(list(map(tensor_utils.bare_name, metric_vars)))

      temp_saver_for_all_vars = create_deterministic_saver(
          var_list=state_vars + metadata_vars + metric_vars,
          name=SAVE_SERVER_SAVEPOINT_NAME,
      )
      temp_saver_def = temp_saver_for_all_vars.as_saver_def()
      save_tensor_name = temp_saver_def.save_tensor_name
  else:
    if additional_checkpoint_metadata_var_fn:
      metadata_vars = additional_checkpoint_metadata_var_fn(
          state_vars, None, write_metrics_to_checkpoint
      )

  saver = create_deterministic_saver(
      var_list=state_vars + metadata_vars,
      name='{}_savepoint'.format(artifact_constants.SERVER_STATE_VAR_PREFIX),
  )
  savepoint = plan_pb2.CheckpointOp()
  savepoint.saver_def.CopyFrom(saver.as_saver_def())

  if save_tensor_name is not None:
    # Replace the save_tensor_name with the one from temp_saver_for_all_vars
    # so that we additionally save the metrics vars in the checkpoint; they
    # don't need to be restored as part of the input computation state.
    # Once we create the server GraphDef, we will edit the GraphDef directly
    # to ensure the input filename links to the filename tensor from the
    # `savepoint`.
    savepoint.saver_def.save_tensor_name = save_tensor_name
  return state_vars, metric_vars, metadata_vars, savepoint


def create_state_vars_and_savepoint(
    type_spec: variable_helpers.AllowedTffTypes, name: str
) -> tuple[list[tf.Variable], plan_pb2.CheckpointOp]:
  """Creates state variables and their savepoint as a `plan_pb2.CheckpointOp`.

  The variables and the associated saver are constructed in the default graph.

  Args:
    type_spec: An instance of `tff.Type` with the type signature of the state.
    name: The string to use as a basis for naming the vars and the saver. The
      vars will be under `${name}_state`, and saver under `${name}_savepoint`.

  Returns:
    A tuple `(vars, savepoint)`, where `vars` is a Python `list` of variables
    that hold the state, and `savepoint` is the associated savepoint, i.e.,
    an instance of `plan_pb2.CheckpointOp` with a saver configured for saving
    and restoring the `vars`.

  Raises:
    ValueError: If the name is empty.
  """
  state_vars, saver = create_state_vars_and_saver(type_spec, name)
  savepoint = plan_pb2.CheckpointOp()
  savepoint.saver_def.CopyFrom(saver.as_saver_def())
  return state_vars, savepoint


def create_state_vars_and_saver(
    type_spec: variable_helpers.AllowedTffTypes, name: str
) -> tuple[list[tf.Variable], tf.compat.v1.train.Saver]:
  """Creates state variables and the associated saver.

  The variables and the associated saver are constructed in the default graph.

  Args:
    type_spec: An instance of `tff.Type` with the type signature of the state.
    name: The string to use as a basis for naming the vars and the saver. The
      vars will be under `${name}_state`, and saver under `${name}_savepoint`.

  Returns:
    A tuple `(vars, saver)`, where `vars` is a Python `list` of variables
    that hold the state, and `saver` is the associated
    `tf.compat.v1.train.Saver`.

  Raises:
    ValueError: If the name is empty.
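
  Example (a sketch, not this module's code): the saver is built with
  `create_deterministic_saver`, which sorts the variables before passing them
  to `tf.compat.v1.train.Saver`, so two builds that create the same variables
  in different orders end up with the same variable ordering. The pure-Python
  snippet below mimics only that ordering step; `FakeVar` and
  `deterministic_order` are hypothetical names, and `FakeVar` stands in for
  `tf.Variable` so the snippet runs without TensorFlow.

```python
from collections import OrderedDict
from collections.abc import Iterable, Mapping
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeVar:
  # Hypothetical stand-in for `tf.Variable`; only `name` is used here.
  name: str


def deterministic_order(var_list):
  # Mirrors the sorting in `create_deterministic_saver`: a Mapping is
  # ordered by key, any other iterable by each variable's `name`.
  if isinstance(var_list, Mapping):
    return OrderedDict(sorted(var_list.items()))
  if isinstance(var_list, Iterable):
    return sorted(var_list, key=lambda v: v.name)
  raise ValueError(f'Unsupported var_list type: {type(var_list)}')


# Two builds that create the same variables in a different order.
build_a = [FakeVar('state/b:0'), FakeVar('state/a:0')]
build_b = [FakeVar('state/a:0'), FakeVar('state/b:0')]
assert deterministic_order(build_a) == deterministic_order(build_b)
print([v.name for v in deterministic_order(build_a)])  # ['state/a:0', 'state/b:0']
```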
  """
  type_checks.check_type(type_spec, tff.Type, name='type_spec')
  type_checks.check_type(name, str, name='name')
  if not name:
    raise ValueError('Name cannot be empty.')
  state_vars = variable_helpers.create_vars_for_tff_type(type_spec, name)
  saver = create_deterministic_saver(
      state_vars, name='{}_savepoint'.format(name)
  )
  return state_vars, saver


def restore_tensors_from_savepoint(
    tensor_specs: Iterable[tf.TensorSpec], filepath_tensor: tf.Tensor
) -> list[tf.Tensor]:
  """Restores tensors from a checkpoint designated by a tensor filepath.

  Args:
    tensor_specs: An iterable of `tf.TensorSpec`s with the names and dtypes of
      the tensors to restore.
    filepath_tensor: A placeholder tensor that contains file names with a given
      pattern.

  Returns:
    A list of restored tensors, in the same order as `tensor_specs`.
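
  Example (a sketch, not this module's code): each spec's name is reduced
  with `tensor_utils.bare_name` before the lookup, so a spec named
  `'server_state/w:0'` restores the checkpoint entry `'server_state/w'`. The
  snippet assumes `bare_name` strips a trailing `:<output index>` suffix, and
  uses a plain dict as a hypothetical stand-in for the checkpoint file;
  `bare_name_sketch` and `restore_from_dict` are hypothetical names.

```python
def bare_name_sketch(tensor_name):
  # Assumption: like `tensor_utils.bare_name`, drop a trailing ':<index>'
  # output suffix, so 'w:0' becomes 'w'.
  head, sep, tail = tensor_name.rpartition(':')
  return head if sep and tail.isdigit() else tensor_name


def restore_from_dict(spec_names, checkpoint):
  # One restored value per spec, in spec order; `checkpoint` is a
  # hypothetical dict standing in for the checkpoint file.
  return [checkpoint[bare_name_sketch(n)] for n in spec_names]


checkpoint = {'server_state/w': [1.0, 2.0], 'server_state/b': [0.5]}
restored = restore_from_dict(['server_state/b:0', 'server_state/w:0'], checkpoint)
print(restored)  # [[0.5], [1.0, 2.0]]
```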
  """
  return [
      tensor_utils.restore(
          filepath_tensor, tensor_utils.bare_name(spec.name), spec.dtype
      )
      for spec in tensor_specs
  ]


def create_deterministic_saver(
    var_list: Union[Iterable[tf.Variable], Mapping[str, tf.Variable]],
    *args,
    **kwargs,
) -> tf.compat.v1.train.Saver:
  """Creates a `tf.compat.v1.train.Saver` that is deterministic.

  This method sorts `var_list` (by key for a `Mapping`, by variable name
  otherwise) so that the saver, and therefore the checkpoint it writes, has a
  deterministic variable ordering.

  Uses `tf.compat.v1.train.SaverDef.V1` version for writing checkpoints.

  Args:
    var_list: An `Iterable` or `str` keyed `Mapping` of `tf.Variables`. In the
      case of a `Mapping`, the keys become the names of the checkpoint
      variables (rather than reading the names off the `tf.Variable` values).
    *args: Positional arguments forwarded to the `tf.compat.v1.train.Saver`
      constructor.
    **kwargs: Keyword arguments forwarded to the `tf.compat.v1.train.Saver`
      constructor.

  Returns:
    A `tf.compat.v1.train.Saver` instance.
  """
  if isinstance(var_list, collections.abc.Mapping):
    deterministic_var_list = collections.OrderedDict(sorted(var_list.items()))
  elif isinstance(var_list, collections.abc.Iterable):
    deterministic_var_list = sorted(var_list, key=lambda v: v.name)
  else:
    raise ValueError(
        'Do not know how to make a deterministic saver for '
        '`var_list` of type [{t}]. Must be a Mapping or Iterable.'.format(
            t=type(var_list)
        )
    )
  return tf.compat.v1.train.Saver(
      deterministic_var_list,
      *args,
      write_version=tf.compat.v1.train.SaverDef.V1,
      **kwargs,
  )


def tff_type_to_dtype_list(
    tff_type: variable_helpers.AllowedTffTypes,
) -> list[tf.DType]:
  """Creates a flat list of `tf.DType`s for tensors in a `tff.Type`.

  Args:
    tff_type: Either a `tff.StructType`, `tff.FederatedType`, or a
      `tff.TensorType` object.

  Returns:
    A flat list of `tf.DType`s.
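
  Example (a sketch using hypothetical stand-in classes, not the TFF type
  API): the traversal is depth-first in element order, a tensor contributes
  its own dtype, and a federated type is transparent, contributing the dtypes
  of its member. `TensorT`, `FederatedT`, `StructT` and `dtype_list` are
  hypothetical names, and dtypes are plain strings here.

```python
from dataclasses import dataclass


@dataclass
class TensorT:
  # Hypothetical stand-in for `tff.TensorType`.
  dtype: str
  shape: tuple = ()


@dataclass
class FederatedT:
  # Hypothetical stand-in for `tff.FederatedType`; placement is omitted.
  member: object


@dataclass
class StructT:
  # Hypothetical stand-in for `tff.StructType`: a list of (name, type) pairs.
  elements: list


def dtype_list(t):
  # Depth-first flattening that mirrors `tff_type_to_dtype_list`.
  if isinstance(t, TensorT):
    return [t.dtype]
  if isinstance(t, FederatedT):
    return dtype_list(t.member)
  if isinstance(t, StructT):
    out = []
    for _, elem in t.elements:
      out.extend(dtype_list(elem))
    return out
  raise TypeError(f'Unsupported type: {t!r}')


state = FederatedT(StructT([
    ('weights', StructT([('w', TensorT('float32', (2, 3))),
                         ('b', TensorT('float32', (3,)))])),
    ('round_num', TensorT('int32')),
]))
print(dtype_list(state))  # ['float32', 'float32', 'int32']
```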
  """
  type_checks.check_type(
      tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
  )
  if isinstance(tff_type, tff.TensorType):
    return [tff_type.dtype]
  elif isinstance(tff_type, tff.FederatedType):
    return tff_type_to_dtype_list(tff_type.member)
  else:  # tff.StructType
    elem_list = []
    for elem_type in tff_type:
      elem_list.extend(tff_type_to_dtype_list(elem_type))
    return elem_list


def tff_type_to_tensor_spec_list(
    tff_type: variable_helpers.AllowedTffTypes,
) -> list[tf.TensorSpec]:
  """Creates a flat list of tensor specs for tensors in a `tff.Type`.

  Args:
    tff_type: Either a `tff.StructType`, `tff.FederatedType`, or a
      `tff.TensorType` object.

  Returns:
    A flat list of `tf.TensorSpec`s.
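
  Example (a sketch under stated assumptions, not this module's code): this
  walks the type in the same depth-first order as `tff_type_to_dtype_list`
  but keeps each leaf's shape as well. That shared flat order is what allows
  a flat list of values to be packed back into the nested structure, as
  `pack_tff_value` does. Types are modelled here as hypothetical tuples
  `('tensor', dtype, shape)`, `('federated', member)` and
  `('struct', [(name, type), ...])`; `spec_list` and `pack` are hypothetical
  names, and `pack` returns nested dicts rather than `tff.structure.Struct`.

```python
def spec_list(t):
  # Leaves become (shape, dtype) pairs, standing in for `tf.TensorSpec`.
  kind = t[0]
  if kind == 'tensor':
    return [(t[2], t[1])]
  if kind == 'federated':
    return spec_list(t[1])
  if kind == 'struct':
    return [spec for _, elem in t[1] for spec in spec_list(elem)]
  raise TypeError(f'Unsupported type kind: {kind}')


def pack(t, values):
  # Leaf-count check first, then rebuild nested dicts from the flat list,
  # walking the type in the same depth-first order as `spec_list`.
  if len(values) != len(spec_list(t)):
    raise ValueError('Number of leaves does not match len(values).')
  it = iter(values)

  def build(node):
    if node[0] == 'tensor':
      return next(it)
    if node[0] == 'federated':
      return build(node[1])
    return {name: build(elem) for name, elem in node[1]}

  return build(t)


model = ('struct', [
    ('w', ('tensor', 'float32', (2,))),
    ('opt', ('federated', ('struct', [('step', ('tensor', 'int64', ()))]))),
])
print(spec_list(model))  # [((2,), 'float32'), ((), 'int64')]
print(pack(model, [[0.1, 0.2], 7]))  # {'w': [0.1, 0.2], 'opt': {'step': 7}}
```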
  """
  type_checks.check_type(
      tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
  )
  if isinstance(tff_type, tff.TensorType):
    return [tf.TensorSpec(tff_type.shape, dtype=tff_type.dtype)]
  elif isinstance(tff_type, tff.FederatedType):
    return tff_type_to_tensor_spec_list(tff_type.member)
  else:  # tff.StructType
    elem_list = []
    for elem_type in tff_type:
      elem_list.extend(tff_type_to_tensor_spec_list(elem_type))
    return elem_list


def pack_tff_value(
    tff_type: variable_helpers.AllowedTffTypes, value_list: Any
) -> Any:
  """Packs a list of values into a shape specified by a `tff.Type`.

  Args:
    tff_type: Either a `tff.StructType`, `tff.FederatedType`, or a
      `tff.TensorType` object.
    value_list: A flat list of `tf.Tensor` or `CheckpointTensorReference`.

  Returns:
    A Python container with a structure consistent with a `tff.Type`.

  Raises:
    ValueError: If the number of leaves in `tff_type` does not match the length
      of `value_list`, or `tff_type` is of a disallowed type.
  """
  type_checks.check_type(
      tff_type, (tff.TensorType, tff.FederatedType, tff.StructType)
  )

  # We must "unwrap" any FederatedTypes because the
  # `tff.structure.pack_sequence_as` call below will fail to recurse into them.
  # Rather than recursing into them, we remove all the FederatedTypes, because
  # we're only trying to build up a Python tree structure that matches the
  # struct/tensor types from a list of values.
  def remove_federated_types(
      type_spec: tff.Type,
  ) -> Union[tff.StructType, tff.TensorType]:
    """Removes `FederatedType` from a type tree, returning a new tree."""
    if type_spec.is_tensor():
      return type_spec
    elif type_spec.is_federated():
      return type_spec.member
    elif type_spec.is_struct():
      return tff.StructType(
          (elem_name, remove_federated_types(elem_type))
          for elem_name, elem_type in tff.structure.iter_elements(type_spec)
      )
    else:
      raise ValueError(
          'Must be either tff.TensorType, tff.FederatedType, or tff.StructType.'
          f' Got a {type(type_spec)}'
      )

  try:
    tff_type = remove_federated_types(tff_type)
  except ValueError as e:
    raise ValueError(
        '`tff_type` is not packable, see earlier error. '
        f'Attempted to pack type: {tff_type}'
    ) from e

  ordered_dtypes = tff_type_to_dtype_list(tff_type)
  if len(ordered_dtypes) != len(value_list):
    raise ValueError(
        'The number of leaves in `tff_type` must equal the length'
        ' of `value_list`. Found `tff_type` with'
        f' {len(ordered_dtypes)} leaves and `value_list` of length'
        f' {len(value_list)}.'
    )

  if tff_type.is_tensor():
    return value_list[0]
  elif tff_type.is_struct():
    return tff.structure.pack_sequence_as(tff_type, value_list)
  else:
    raise ValueError(
        '`tff_type` must be either tff.TensorType or '
        'tff.StructType, reaching here is an internal coding '
        'error, please file a bug.'
    )


def variable_names_from_structure(
    tff_structure: Union[tff.structure.Struct, tf.Tensor], name: str = 'v'
) -> list[str]:
  """Creates a flattened list of variable names for the given structure.

  If `tff_structure` is a `tf.Tensor`, the name is the `name` parameter if
  specified, otherwise the default name `v`. If `tff_structure` is a
  `tff.structure.Struct`, '/' is used to join outer and inner field names,
  using each element's name in the struct or, for unnamed elements, its index.

  Some examples:
  1. If the `tff_structure` is `<'a'=tf.constant(1.0), 'b'=tf.constant(0.0)>`
     and `name` is not specified, the returned variable name list is
     ['v/a', 'v/b'].
  2. If the `tff_structure` is `<None=tf.constant(1.0), None=tf.constant(0.0)>`
     and `name` is `update`, the returned variable name list is
     ['update/0', 'update/1'].
  3. If the `tff_structure` is
     `<'a'=<'b'=tf.constant(1.0), 'c'=tf.constant(0.0)>>` and `name` is
     `update`, the returned variable name list is ['update/a/b', 'update/a/c'].
  4. If the `tff_structure` is
     `<'a'=<'b'=tf.constant(1.0), 'c'=tf.constant(1.0), tf.constant(0.0)>>` and
     `name` is `update`, the returned variable name list is ['update/a/b',
     'update/a/c', 'update/a/2'].

  Args:
    tff_structure: Either a `tff.structure.Struct` or a `tf.Tensor` object.
    name: The preferred name to use at the top-most level; must be a string.
      If `tff_structure` is a `tff.structure.Struct`, the names of the inner
      fields will be scoped under `name`, e.g. `some_name/field_name`.

  Returns:
    A flat Python `list` of `str` names.

  Raises:
    TypeError: If either argument is of the wrong type.
  """
  type_checks.check_type(
      tff_structure, (tff.structure.Struct, tf.Tensor), name='structure_type'
  )
  type_checks.check_type(name, str, name='name')
  if isinstance(tff_structure, tf.Tensor):
    return [name]
  elif isinstance(tff_structure, tff.structure.Struct):
    result = []
    fields = tff.structure.iter_elements(tff_structure)
    for index, (field_name, field_type) in enumerate(fields):
      # Default the name of the element to its index so that we don't wind up
      # with multiple child fields listed under `/v/`.
      field_name = field_name or str(index)
      result.extend(
          variable_names_from_structure(
              field_type, name=name + '/' + field_name
          )
      )
    return result
  else:
    raise TypeError(
        'Cannot create variable names from [{t}] type. Short-hand: {s}'.format(
            t=type(tff_structure), s=tff_structure
        )
    )


def is_structure_of_allowed_types(
    structure: Union[
        tff.structure.Struct,
        tf.Tensor,
        np.ndarray,
        np.number,
        int,
        float,
        str,
        bytes,
    ]
) -> bool:
  """Checks whether every leaf of `structure` is a type allowed for serialization."""
  flattened_structure = tff.structure.flatten(structure)
  for item in flattened_structure:
    if not (
        tf.is_tensor(item)
        or isinstance(item, (np.ndarray, np.number, int, float, str, bytes))
    ):
      return False
  return True


def save_tff_structure_to_checkpoint(
    tff_structure: Union[tff.structure.Struct, tf.Tensor],
    ordered_var_names: list[str],
    output_checkpoint_path: str,
) -> None:
  """Saves a TFF structure to a checkpoint file.

  The input `tff_structure` is either a `tff.structure.Struct` or a single
  `tf.Tensor`. This function saves `tff_structure` to a checkpoint file using
  the variable names supplied via the `ordered_var_names` argument.

  Args:
    tff_structure: A `tff.structure.Struct` of values or a single value. Each
      leaf in the structure must be a value serializable to a TensorFlow
      checkpoint: a `tf.Tensor`, `np.ndarray`, `np.number`, or a Python `int`,
      `float`, `str`, or `bytes`.
    ordered_var_names: The list of variable names for the values that appear in
      `tff_structure` after calling `tff.structure.flatten()`.
    output_checkpoint_path: A string specifying the path to the output
      checkpoint file.

  Raises:
    TypeError: If not all leaves in `tff_structure` are of allowed types.
    ValueError: If the number of leaves in `tff_structure` does not match the
      length of `ordered_var_names`.
  """
  if not is_structure_of_allowed_types(tff_structure):
    raise TypeError(
        'Not all leaves in `tff_structure` are `tf.Tensor`s, '
        '`np.ndarray`s, `np.number`s, or Python scalars. Got: '
        f'{tff.structure.map_structure(type, tff_structure)!r}'
    )

  tensors = tff.structure.flatten(tff_structure)
  if len(tensors) != len(ordered_var_names):
    raise ValueError(
        'The length of `ordered_var_names` does not match the '
        'number of tensors in `tff_structure`: '
        f'{len(ordered_var_names)} != {len(tensors)}'
    )

  tensor_utils.save(
      output_checkpoint_path, tensor_names=ordered_var_names, tensors=tensors
  )
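

# Illustrative sketch (hypothetical; not part of this module's API and not
# called anywhere in it). It mirrors the naming scheme documented in
# `variable_names_from_structure` using plain Python data in place of
# `tff.structure.Struct`, so the scheme can be exercised without TensorFlow or
# TFF. Assumption: a structure is modeled as a Python list of
# `(field_name_or_None, value)` pairs, and any value that is not such a list
# stands in for a leaf `tf.Tensor`.


def _sketch_variable_names(structure: object, name: str = 'v') -> list[str]:
  """Returns flattened '/'-separated names for a list-of-pairs structure."""
  if not isinstance(structure, list):
    # A leaf (stand-in for a `tf.Tensor`) takes the scoped name as-is.
    return [name]
  result = []
  for index, (field_name, value) in enumerate(structure):
    # Unnamed fields fall back to their positional index, like the real helper.
    field_name = field_name or str(index)
    result.extend(_sketch_variable_names(value, name=name + '/' + field_name))
  return result


# For example, the sketch reproduces example 4 from the docstring above:
# _sketch_variable_names([('a', [('b', 1.0), ('c', 1.0), (None, 0.0)])],
#                        name='update')
# returns ['update/a/b', 'update/a/c', 'update/a/2'].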