# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper methods for proto creation logic."""

from typing import Optional

import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import tensor_utils
from fcp.artifact_building import type_checks
from fcp.protos import plan_pb2


def make_tensor_spec_from_tensor(
    t: tf.Tensor, shape_hint: Optional[tf.TensorShape] = None
) -> tf.TensorSpec:
  """Creates a `tf.TensorSpec` from a tensor, with an optional shape hint.

  Args:
    t: A `tf.Tensor` instance to be used to create a `TensorSpec`.
    shape_hint: A `tf.TensorShape` that provides a fully defined shape in the
      case that `t` is partially defined. If `t` has a fully defined shape,
      `shape_hint` is ignored. `shape_hint` must be compatible with the
      partially defined shape of `t`.

  Returns:
    A `tf.TensorSpec` instance corresponding to the input `tf.Tensor`.

  Raises:
    NotImplementedError: If `t` is not a tensor.
    TypeError: If `shape_hint` is not `None` and is incompatible with the
      static shape of `t`.
  """
  if not tf.is_tensor(t):
    raise NotImplementedError(
        'Cannot handle type {t}: {v}'.format(t=type(t), v=t)
    )
  derived_shape = tf.TensorShape(t.shape)
  if not derived_shape.is_fully_defined() and shape_hint is not None:
    if derived_shape.is_compatible_with(shape_hint):
      shape = shape_hint
    else:
      raise TypeError(
          'shape_hint is not compatible with tensor ('
          f'{shape_hint} vs {derived_shape})'
      )
  else:
    shape = derived_shape
  return tf.TensorSpec(shape, t.dtype, name=t.name)


def make_measurement(
    t: tf.Tensor, name: str, tff_type: tff.types.TensorType
) -> plan_pb2.Measurement:
  """Creates a `plan_pb2.Measurement` descriptor for a tensor.

  Args:
    t: A tensor to create the measurement for.
    name: The name of the measurement (e.g. 'server/loss').
    tff_type: The `tff.types.TensorType` of the measurement.

  Returns:
    An instance of `plan_pb2.Measurement`.

  Raises:
    TypeError: If `tff_type` is not a `tff.types.TensorType`.
    ValueError: If the `dtype`s of the provided tensor and TFF type differ, or
      if both shapes are fully defined and do not match.
  """
  type_checks.check_type(tff_type, tff.types.TensorType)
  if tff_type.dtype != t.dtype:
    raise ValueError(
        f'`tff_type.dtype`: {tff_type.dtype} does not match '
        f"provided tensor's dtype: {t.dtype}."
    )
  if tff_type.shape.is_fully_defined() and t.shape.is_fully_defined():
    if tff_type.shape.as_list() != t.shape.as_list():
      raise ValueError(
          f'`tff_type.shape`: {tff_type.shape} does not match '
          f"provided tensor's shape: {t.shape}."
      )
  return plan_pb2.Measurement(
      read_op_name=t.name,
      name=name,
      tff_type=tff.types.serialize_type(tff_type).SerializeToString(),
  )


def make_metric(v: tf.Variable, stat_name_prefix: str) -> plan_pb2.Metric:
  """Creates a `plan_pb2.Metric` descriptor for a resource variable.

  The stat name is `stat_name_prefix`, a slash, and the variable's bare name
  with its first `/`-delimited component removed; any colon-based suffix
  (e.g. `:0`) is dropped. For example, a variable named
  `client/metrics/loss:0` with prefix `stats` yields `stats/metrics/loss`.

  Args:
    v: A variable to create the metric descriptor for.
    stat_name_prefix: The prefix (string) to use in formulating a stat name,
      excluding the trailing slash `/` (added automatically).

  Returns:
    An instance of `plan_pb2.Metric` for `v`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the arguments are malformed (e.g., no leading name prefix).
  """
  type_checks.check_type(stat_name_prefix, str, name='stat_name_prefix')
  if not hasattr(v, 'read_value'):
    raise TypeError('Expected a resource variable, found {!r}.'.format(type(v)))
  bare_name = tensor_utils.bare_name(v.name)
  if '/' not in bare_name:
    raise ValueError(
        'Expected a prefix in the name, found none in {}.'.format(bare_name)
    )
  stat_name = '{}/{}'.format(
      stat_name_prefix, bare_name[(bare_name.find('/') + 1) :]
  )
  return plan_pb2.Metric(variable_name=v.read_value().name, stat_name=stat_name)
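

# Illustrative sketch (not part of the original FCP module) of the stat-name
# rule applied by `make_metric`, restated with plain string operations so it
# can be checked without TensorFlow. `_sketch_stat_name` is a hypothetical
# helper; it assumes `tensor_utils.bare_name` drops everything from the first
# colon onward (e.g. the `:0` output suffix).
def _sketch_stat_name(variable_name: str, stat_name_prefix: str) -> str:
  """Hypothetical helper mirroring the stat-name logic of `make_metric`."""
  # Assumed `bare_name` behavior: 'a/b:0' -> 'a/b'.
  bare = variable_name.split(':', 1)[0]
  if '/' not in bare:
    raise ValueError(f'Expected a prefix in the name, found none in {bare}.')
  # Replace the first '/'-delimited component with the caller's prefix.
  return '{}/{}'.format(stat_name_prefix, bare[bare.find('/') + 1 :])


# Usage (hypothetical names):
#   _sketch_stat_name('client/metrics/loss:0', 'stats')  # 'stats/metrics/loss'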