# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper methods for doing runtime type checks."""

from typing import Any, Optional, Tuple, Type, Union

import tensorflow as tf


def _format_name_for_error(name: Optional[Any]) -> str:
  """Formats an optional object name for `check_*` error messages.

  Args:
    name: Optional name of the object being checked. If unspecified, or if it
      is otherwise falsy (for example an empty string), a generic placeholder
      is used instead.

  Returns:
    The name wrapped in backticks (e.g. '`foo`'), or the string 'argument'
    when no usable name was provided.
  """
  # Truthiness (not `is None`) is intentional: an empty name would render as
  # an unhelpful '``' in the message, so it falls back to the placeholder too.
  return f'`{name}`' if name else 'argument'


def check_type(
    obj: Any,
    t: Union[Type[Any], Tuple[Type[Any], ...]],
    name: Optional[str] = None,
) -> None:
  """Checks if an object is an instance of a type.

  Args:
    obj: The object to check.
    t: The type (or tuple of types, as accepted by `isinstance`) that `obj`
      must be an instance of.
    name: Optional name of the object being checked. Will be included in the
      error message if specified.

  Raises:
    TypeError: If `obj` is not an instance of `t`.
  """
  if not isinstance(obj, t):
    msg_name = _format_name_for_error(name)
    raise TypeError(
        f'Expected {msg_name} to be an instance of type {t!r}, but '
        f'found an instance of type {type(obj)!r}.'
    )


def check_callable(obj: Any, name: Optional[str] = None) -> None:
  """Checks if an object is a Python callable.

  Args:
    obj: The object to check.
    name: Optional name of the object being checked. Will be included in the
      error message if specified.

  Raises:
    TypeError: If `obj` is not a Python callable.
  """
  if not callable(obj):
    msg_name = _format_name_for_error(name)
    raise TypeError(
        f'Expected {msg_name} to be callable, but found an '
        f'instance of {type(obj)!r}.'
    )


def check_dataset(obj: Any, name: Optional[str] = None) -> None:
  """Checks that the runtime type of the input is a TensorFlow Dataset.

  TensorFlow exposes several classes that conform to the Dataset API. This
  function accepts an instance of any of `tf.data.Dataset`,
  `tf.compat.v1.data.Dataset` or `tf.compat.v2.data.Dataset`.

  Args:
    obj: The object to check. Any value may be passed; verifying that it is a
      Dataset is the purpose of this function.
    name: Optional name of the object being checked. Will be included in the
      error message if specified.

  Raises:
    TypeError: If `obj` is not an instance of any of the known Dataset types.
  """
  dataset_types = (
      tf.data.Dataset,
      tf.compat.v1.data.Dataset,
      tf.compat.v2.data.Dataset,
  )
  if not isinstance(obj, dataset_types):
    msg_name = _format_name_for_error(name)
    raise TypeError(
        f'Expected {msg_name} to be a Dataset, but found an '
        f'instance of {type(obj).__name__}.'
    )