xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/proto_helpers.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Helper methods for proto creation logic."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerfrom typing import Optional
17*14675a02SAndroid Build Coastguard Worker
18*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
19*14675a02SAndroid Build Coastguard Workerimport tensorflow_federated as tff
20*14675a02SAndroid Build Coastguard Worker
21*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import tensor_utils
22*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import type_checks
23*14675a02SAndroid Build Coastguard Workerfrom fcp.protos import plan_pb2
24*14675a02SAndroid Build Coastguard Worker
25*14675a02SAndroid Build Coastguard Worker
def make_tensor_spec_from_tensor(
    t: tf.Tensor, shape_hint: Optional[tf.TensorShape] = None
) -> tf.TensorSpec:
  """Builds a `tf.TensorSpec` describing `t`, optionally refined by a hint.

  The returned spec always takes its dtype and name from `t`. Its shape is the
  static shape of `t`, except when that shape is only partially defined and a
  `shape_hint` was supplied: then the hint is used in its place, provided it is
  compatible with the static shape.

  Args:
    t: The `tf.Tensor` to describe.
    shape_hint: Optional shape substituted for `t`'s static shape when the
      latter is not fully defined. Ignored when `t`'s shape is fully defined.
      Must be compatible with `t`'s static shape.

  Returns:
    A `tf.TensorSpec` with the selected shape, `t.dtype`, and `t.name`.

  Raises:
    NotImplementedError: If `tf.is_tensor(t)` is false.
    TypeError: If `shape_hint` would be used but is incompatible with the
      static shape of `t`.
  """
  if not tf.is_tensor(t):
    raise NotImplementedError(
        'Cannot handle type {t}: {v}'.format(t=type(t), v=t)
    )
  static_shape = tf.TensorShape(t.shape)

  # The hint only matters when the tensor's own shape leaves something unknown.
  needs_hint = shape_hint is not None and not static_shape.is_fully_defined()
  if not needs_hint:
    return tf.TensorSpec(static_shape, t.dtype, name=t.name)

  if not static_shape.is_compatible_with(shape_hint):
    raise TypeError(
        'shape_hint is not compatible with tensor ('
        f'{shape_hint} vs {static_shape})'
    )
  return tf.TensorSpec(shape_hint, t.dtype, name=t.name)
62*14675a02SAndroid Build Coastguard Worker
63*14675a02SAndroid Build Coastguard Worker
def make_measurement(
    t: tf.Tensor, name: str, tff_type: tff.types.TensorType
) -> plan_pb2.Measurement:
  """Creates a `plan_pb2.Measurement` descriptor for a tensor.

  Args:
    t: The tensor to create the measurement for. Its `name` is recorded as the
      measurement's `read_op_name`.
    name: The name of the measurement (e.g. 'server/loss').
    tff_type: The `tff.types.TensorType` of the measurement. It is serialized
      into the descriptor's `tff_type` field.

  Returns:
    A `plan_pb2.Measurement` with `read_op_name`, `name` and the serialized
    `tff_type` populated.

  Raises:
    TypeError: If `tff_type` is not a `tff.types.TensorType` (reported by
      `type_checks.check_type`; presumed to raise `TypeError`).
    ValueError: If the dtypes of `t` and `tff_type` differ, or if their shapes
      are incompatible (different ranks, or a dimension known on both sides
      with different sizes).
  """
  type_checks.check_type(tff_type, tff.types.TensorType, name='tff_type')
  if tff_type.dtype != t.dtype:
    raise ValueError(
        f'`tff_type.dtype`: {tff_type.dtype} does not match '
        f"provided tensor's dtype: {t.dtype}."
    )
  # `is_compatible_with` reduces to exact equality when both shapes are fully
  # defined, and additionally catches rank mismatches and conflicting known
  # dimensions when either shape is only partially defined. Those cases were
  # previously skipped entirely.
  if not tff_type.shape.is_compatible_with(t.shape):
    raise ValueError(
        f'`tff_type.shape`: {tff_type.shape} does not match '
        f"provided tensor's shape: {t.shape}."
    )
  return plan_pb2.Measurement(
      read_op_name=t.name,
      name=name,
      tff_type=tff.types.serialize_type(tff_type).SerializeToString(),
  )
98*14675a02SAndroid Build Coastguard Worker
99*14675a02SAndroid Build Coastguard Worker
def make_metric(v: tf.Variable, stat_name_prefix: str) -> plan_pb2.Metric:
  """Builds a `plan_pb2.Metric` that reports the value of a resource variable.

  The stat name is built from three parts, in order:
    1. `stat_name_prefix`;
    2. a `/`;
    3. the variable's bare name (from `tensor_utils.bare_name`) with everything
       up to and including its first `/` removed.
  For example, a bare name of `scope/a/b` with prefix `client` yields the stat
  name `client/a/b`.

  Args:
    v: The variable to describe. Any object exposing `read_value` is accepted.
    stat_name_prefix: Prefix for the stat name, without a trailing `/` (one is
      inserted automatically).

  Returns:
    A `plan_pb2.Metric` whose `variable_name` is the name of the tensor returned
    by `v.read_value()` and whose `stat_name` is built as described above.

  Raises:
    TypeError: If `stat_name_prefix` is not a string, or `v` has no
      `read_value` attribute.
    ValueError: If the variable's bare name contains no `/`.
  """
  type_checks.check_type(stat_name_prefix, str, name='stat_name_prefix')
  if not hasattr(v, 'read_value'):
    raise TypeError('Expected a resource variable, found {!r}.'.format(type(v)))

  var_name = tensor_utils.bare_name(v.name)
  # Split at the first '/' only; the leading scope component is discarded.
  _, separator, remainder = var_name.partition('/')
  if not separator:
    raise ValueError(
        'Expected a prefix in the name, found none in {}.'.format(var_name)
    )
  stat_name = f'{stat_name_prefix}/{remainder}'
  return plan_pb2.Metric(variable_name=v.read_value().name, stat_name=stat_name)
130