# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A class to specify on-device dataset inputs."""

from collections.abc import Callable
from typing import Any, Optional, Union

import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import type_checks
from fcp.protos import plan_pb2


class DataSpec:
  """Describes one dataset input: an example selector and its preprocessing.

  The preprocessing function is optional. When it is provided, it is wrapped
  into a `tff.Computation` the first time `preprocessing_comp` is read, and
  that computation is then reused on later reads.
  """

  # NOTE(review): `_fingerprint` is reserved as a slot but nothing in this
  # file assigns it, so reading it on a fresh instance raises AttributeError.
  # Confirm whether other code still relies on it.
  __slots__ = (
      '_example_selector_proto',
      '_preprocessing_fn',
      '_preprocessing_comp',
      '_fingerprint',
  )

  def __init__(
      self,
      example_selector_proto: plan_pb2.ExampleSelector,
      preprocessing_fn: Optional[
          Callable[[tf.data.Dataset], tf.data.Dataset]
      ] = None,
  ):
    """Builds the specification of a dataset input.

    Args:
      example_selector_proto: A `plan_pb2.ExampleSelector` proto instance.
      preprocessing_fn: Optional callable that takes the raw input
        `tf.data.Dataset` of `string`-serialized items, applies whatever
        preprocessing is wanted (deserialization, filtering, batching,
        formatting, ...) and returns the resulting `tf.data.Dataset`. When it
        is None, any client data preprocessing is expected to already be part
        of the `tff.Computation` this `DataSpec` is associated with.

    Raises:
      TypeError: If the types of the arguments are invalid.
    """
    # Validate both arguments before touching any instance state.
    type_checks.check_type(
        example_selector_proto,
        plan_pb2.ExampleSelector,
        name='example_selector_proto',
    )
    if preprocessing_fn is not None:
      type_checks.check_callable(preprocessing_fn, name='preprocessing_fn')

    # The computation is built lazily by the `preprocessing_comp` property
    # because tff.computation cannot be called from __init__.
    self._preprocessing_comp = None
    self._preprocessing_fn = preprocessing_fn
    self._example_selector_proto = example_selector_proto

  @property
  def example_selector_proto(self) -> plan_pb2.ExampleSelector:
    """The example selector proto passed to the constructor."""
    return self._example_selector_proto

  @property
  def preprocessing_fn(
      self,
  ) -> Optional[Callable[[tf.data.Dataset], tf.data.Dataset]]:
    """The preprocessing callable passed to the constructor, or None."""
    return self._preprocessing_fn

  @property
  def preprocessing_comp(self) -> tff.Computation:
    """Returns the preprocessing computation for the input dataset.

    The computation is created on first access from `preprocessing_fn`, with
    `tff.SequenceType(tf.string)` as its parameter type, and cached afterwards.

    Raises:
      ValueError: If no computation is cached yet and `preprocessing_fn` is
        None.
    """
    cached_comp = self._preprocessing_comp
    if cached_comp is not None:
      return cached_comp

    fn = self.preprocessing_fn
    if fn is None:
      raise ValueError(
          "DataSpec's preprocessing_fn is None so a "
          'preprocessing tff.Computation cannot be generated.'
      )

    built_comp = tff.tf_computation(fn, tff.SequenceType(tf.string))
    self._preprocessing_comp = built_comp
    return built_comp

  @property
  def type_signature(self) -> tff.Type:
    """Returns the result type of `preprocessing_comp`.

    Effectively the type or 'spec' of the parsed example from the example
    store pointed at by `example_selector_proto`.
    """
    comp_signature = self.preprocessing_comp.type_signature
    return comp_signature.result


def is_data_spec_or_structure(x: Any) -> bool:
  """Returns True iff `x` is either a `DataSpec` or a nested structure of it.

  None is rejected. Anything that is not a `DataSpec` is converted with
  `tff.structure.from_container` and every element is checked recursively; a
  `TypeError` raised along the way is reported as False.
  """
  if x is None:
    return False
  if isinstance(x, DataSpec):
    return True
  try:
    struct = tff.structure.from_container(x)
    for _, member in tff.structure.to_elements(struct):
      if not is_data_spec_or_structure(member):
        return False
  except TypeError:
    return False
  # Every element passed (this is also reached when there are no elements).
  return True


def check_data_spec_or_structure(x: Any, name: str):
  """Raises error iff `x` is not a `DataSpec` or a nested structure of it.

  Args:
    x: The value to validate.
    name: Label for `x` that is embedded in the error message.

  Raises:
    TypeError: If `is_data_spec_or_structure(x)` is False.
  """
  if is_data_spec_or_structure(x):
    return
  raise TypeError(
      f'Expected `{name}` to be a `DataSpec` or a nested '
      f'structure of it, found {str(x)}.'
  )


NestedDataSpec = Union[DataSpec, dict[str, 'NestedDataSpec']]


def generate_example_selector_bytes_list(ds: NestedDataSpec):
  """Returns an ordered list of the bytes of each DataSpec's example selector.

  The order aligns with the order of a struct given by
  tff.structure.to_elements().

  Args:
    ds: A `NestedDataSpec`.
  """
  # Leaf: a single DataSpec contributes exactly one serialized selector.
  if isinstance(ds, DataSpec):
    return [ds.example_selector_proto.SerializeToString()]

  struct = tff.structure.from_container(ds)
  assert isinstance(struct, tff.structure.Struct)
  # Flatten depth-first, keeping the element order reported by to_elements().
  return [
      selector_bytes
      for _, child in tff.structure.to_elements(struct)
      for selector_bytes in generate_example_selector_bytes_list(child)
  ]