xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/data_spec.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""A class to specify on-device dataset inputs."""
15
16from collections.abc import Callable
17from typing import Any, Optional, Union
18
19import tensorflow as tf
20import tensorflow_federated as tff
21
22from fcp.artifact_building import type_checks
23from fcp.protos import plan_pb2
24
25
class DataSpec:
  """Describes one on-device dataset input.

  A `DataSpec` pairs a `plan_pb2.ExampleSelector`, which points at the example
  store to read from, with an optional preprocessing function that turns the
  raw `tf.data.Dataset` of serialized examples into what the client
  computation consumes.
  """

  # NOTE(review): `_fingerprint` is declared here but never assigned anywhere
  # in this module; reading it on an instance raises AttributeError.
  __slots__ = (
      '_example_selector_proto',
      '_preprocessing_fn',
      '_preprocessing_comp',
      '_fingerprint',
  )

  def __init__(
      self,
      example_selector_proto: plan_pb2.ExampleSelector,
      preprocessing_fn: Optional[
          Callable[[tf.data.Dataset], tf.data.Dataset]
      ] = None,
  ):
    """Builds a `DataSpec`.

    Args:
      example_selector_proto: The `plan_pb2.ExampleSelector` proto selecting
        the example store this input is read from.
      preprocessing_fn: Optional callable that receives the raw
        `tf.data.Dataset` of `string`-serialized items and returns a
        transformed `tf.data.Dataset` (deserialization, filtering, batching,
        formatting, ...). Leave it as `None` when the `tff.Computation` this
        `DataSpec` is associated with already performs any client data
        preprocessing itself.

    Raises:
      TypeError: If an argument has an invalid type.
    """
    type_checks.check_type(
        example_selector_proto,
        plan_pb2.ExampleSelector,
        name='example_selector_proto',
    )
    # Only validate callability when a function was actually supplied.
    if preprocessing_fn is None:
      pass
    else:
      type_checks.check_callable(preprocessing_fn, name='preprocessing_fn')

    self._example_selector_proto = example_selector_proto
    self._preprocessing_fn = preprocessing_fn
    # Filled in lazily by the `preprocessing_comp` property: tff.computation
    # cannot be invoked from within `__init__`.
    self._preprocessing_comp = None

  @property
  def example_selector_proto(self) -> plan_pb2.ExampleSelector:
    """The `plan_pb2.ExampleSelector` given at construction."""
    return self._example_selector_proto

  @property
  def preprocessing_fn(
      self,
  ) -> Optional[Callable[[tf.data.Dataset], tf.data.Dataset]]:
    """The preprocessing callable given at construction, or `None`."""
    return self._preprocessing_fn

  @property
  def preprocessing_comp(self) -> tff.Computation:
    """Lazily builds, caches and returns the preprocessing computation.

    On first access, `preprocessing_fn` is wrapped with `tff.tf_computation`
    using a `tff.SequenceType(tf.string)` parameter type. Later accesses
    return the cached computation.

    Raises:
      ValueError: If no computation is cached yet and `preprocessing_fn` is
        None.
    """
    cached_comp = self._preprocessing_comp
    if cached_comp is not None:
      return cached_comp

    fn = self.preprocessing_fn
    if fn is None:
      raise ValueError(
          "DataSpec's preprocessing_fn is None so a "
          'preprocessing tff.Computation cannot be generated.'
      )

    cached_comp = tff.tf_computation(fn, tff.SequenceType(tf.string))
    self._preprocessing_comp = cached_comp
    return cached_comp

  @property
  def type_signature(self) -> tff.Type:
    """The result type of `preprocessing_comp`.

    Effectively the type ('spec') of a parsed example from the example store
    selected by `example_selector_proto`. Reading this property goes through
    `preprocessing_comp`, so it raises `ValueError` in the same situation.
    """
    comp = self.preprocessing_comp
    return comp.type_signature.result
103
104
def is_data_spec_or_structure(x: Any) -> bool:
  """Tells whether `x` is a `DataSpec` or a nested structure of `DataSpec`s.

  Args:
    x: The value to inspect.

  Returns:
    True if `x` is a `DataSpec`, or if `tff.structure.from_container` can
    convert `x` and every one of its elements (recursively) satisfies this same
    predicate. Returns False for `None` and for any value whose conversion
    raises `TypeError`.
  """
  if x is None:
    return False
  if isinstance(x, DataSpec):
    return True

  try:
    struct = tff.structure.from_container(x)
  except TypeError:
    # Not a `DataSpec` and not a container that TFF can convert.
    return False

  # Each recursive call handles its own conversion failures and returns False,
  # so nothing below can raise `TypeError`.
  return all(
      is_data_spec_or_structure(element)
      for _, element in tff.structure.to_elements(struct)
  )
117    return False
118
119
def check_data_spec_or_structure(x: Any, name: str):
  """Validates that `x` is a `DataSpec` or a nested structure of them.

  Args:
    x: The value to validate.
    name: Name of the argument being validated, used in the error message.

  Raises:
    TypeError: If `is_data_spec_or_structure(x)` is False.
  """
  if is_data_spec_or_structure(x):
    return
  message = (
      f'Expected `{name}` to be a `DataSpec` or a nested '
      f'structure of it, found {x!s}.'
  )
  raise TypeError(message)
126    )
127
128
# Recursive type alias: either a single `DataSpec`, or a dict mapping string
# names to further `NestedDataSpec` values.
# NOTE(review): `is_data_spec_or_structure` accepts anything that
# `tff.structure.from_container` can convert, which may be broader than this
# dict-only alias (e.g. lists/tuples) — confirm whether the alias should widen.
NestedDataSpec = Union[DataSpec, dict[str, 'NestedDataSpec']]
130
131
def generate_example_selector_bytes_list(ds: NestedDataSpec) -> list[bytes]:
  """Returns an ordered list of the bytes of each DataSpec's example selector.

  The structure is flattened depth-first. At every level, elements are visited
  in the order given by `tff.structure.to_elements()`, so the result lines up
  with that struct ordering.

  Args:
    ds: A `NestedDataSpec`.

  Returns:
    One `SerializeToString()` result of `example_selector_proto` per `DataSpec`
    leaf of `ds`, in traversal order. A lone `DataSpec` yields a one-element
    list. A structure with no elements yields an empty list.

  Raises:
    TypeError: Presumably raised by `tff.structure.from_container` when `ds`,
      or a nested leaf, is neither a `DataSpec` nor a convertible container.
  """
  if isinstance(ds, DataSpec):
    return [ds.example_selector_proto.SerializeToString()]

  struct = tff.structure.from_container(ds)
  # Narrows the type for readers and static checkers; `from_container` is
  # expected to always produce a `Struct`.
  assert isinstance(struct, tff.structure.Struct)
  # Flatten the per-element selector lists in one pass instead of an
  # append/extend loop.
  return [
      selector_bytes
      for _, element in tff.structure.to_elements(struct)
      for selector_bytes in generate_example_selector_bytes_list(element)
  ]
150    return selector_bytes_list
151