xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/type_checks.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Helper methods for doing runtime type checks."""
15
16from typing import Any, Optional, Tuple, Type, Union
17
18import tensorflow as tf
19
20
21def _format_name_for_error(name: Optional[Any]) -> str:
22  """Formats an optional object name for `check_*` error messages.
23
24  Args:
25    name: Optional name of the object being checked. If unspecified, will use a
26      placeholder object name instead.
27
28  Returns:
29    A formatted name for the object suitable for including in error messages.
30  """
31  return f'`{name}`' if name else 'argument'
32
33
def check_type(
    obj: Any,
    t: Union[Type[Any], Tuple[Type[Any], ...]],
    name: Optional[str] = None,
) -> None:
  """Raises a `TypeError` unless `obj` is an instance of `t`.

  Args:
    obj: The object whose runtime type is verified.
    t: A type, or a tuple of types, in the same form accepted by the second
      argument of the builtin `isinstance`.
    name: Optional name of `obj`. Included in the error message when given.

  Raises:
    TypeError: If `isinstance(obj, t)` is false.
  """
  if isinstance(obj, t):
    return
  subject = _format_name_for_error(name)
  actual_type = type(obj)
  raise TypeError(
      f'Expected {subject} to be an instance of type {t!r}, but '
      f'found an instance of type {actual_type!r}.'
  )
56
57
def check_callable(obj: Any, name: Optional[str] = None) -> None:
  """Raises a `TypeError` unless `obj` can be called.

  Args:
    obj: The object to test with the builtin `callable`.
    name: Optional name of `obj`. Included in the error message when given.

  Raises:
    TypeError: If `callable(obj)` is false.
  """
  if callable(obj):
    return
  subject = _format_name_for_error(name)
  actual_type = type(obj)
  raise TypeError(
      f'Expected {subject} to be callable, but found an '
      f'instance of {actual_type!r}.'
  )
75
76
def check_dataset(
    obj: Any,
    name: Optional[str] = None,
) -> None:
  """Checks that the runtime type of the input is a TensorFlow Dataset.

  TensorFlow exposes several classes that conform to the Dataset API. This
  method accepts an instance of any of `tf.data.Dataset`,
  `tf.compat.v1.data.Dataset` or `tf.compat.v2.data.Dataset`.

  Args:
    obj: The input object to check. Annotated as `Any`, like the other
      `check_*` helpers, because this function exists to reject non-Dataset
      values at runtime.
    name: Optional name of the object being checked. Will be included in the
      error message if specified.

  Raises:
    TypeError: If `obj` is not an instance of any of the Dataset types above.
  """
  dataset_types = (
      tf.data.Dataset,
      tf.compat.v1.data.Dataset,
      tf.compat.v2.data.Dataset,
  )
  if not isinstance(obj, dataset_types):
    msg_name = _format_name_for_error(name)
    raise TypeError(
        f'Expected {msg_name} to be a Dataset; but found an '
        f'instance of {type(obj).__name__}.'
    )
103    )
104