# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for managing tf.data user-defined functions."""

import warnings

from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as eager_function

from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import script_ops
from tensorflow.python.util import function_utils
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import variable_utils

autograph = lazy_loader.LazyLoader(
    "autograph", globals(),
    "tensorflow.python.autograph.impl.api")
# TODO(mdan): Create a public API for this.
autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")
dataset_ops = lazy_loader.LazyLoader(
    "dataset_ops", globals(),
    "tensorflow.python.data.ops.dataset_ops")


def _should_pack(arg):
  """Determines whether the caller needs to pack the argument in a tuple.

  If user-defined function returns a list of tensors, `nest.flatten()` and
  `ops.convert_to_tensor()` would conspire to attempt to stack those tensors
  into a single tensor because the tf.data version of `nest.flatten()` does
  not recurse into lists. Since it is more likely that the list arose from
  returning the result of an operation (such as `tf.numpy_function()`) that
  returns a list of not-necessarily-stackable tensors, we treat the returned
  value as a `tuple` instead. A user wishing to pack the return value into a
  single tensor can use an explicit `tf.stack()` before returning.

  Args:
    arg: argument to check

  Returns:
    Indication of whether the caller needs to pack the argument in a tuple.
  """
  return isinstance(arg, list)


def _should_unpack(arg):
  """Determines whether the caller needs to unpack the argument from a tuple.

  Only exact `tuple` instances qualify; subclasses of `tuple` (for example
  named tuples) are deliberately excluded by the `type(...) is` comparison.

  Args:
    arg: argument to check

  Returns:
    Indication of whether the caller needs to unpack the argument from a tuple.
  """
  return type(arg) is tuple  # pylint: disable=unidiomatic-typecheck


class StructuredFunctionWrapper():
  """A function wrapper that supports structured arguments and return values."""

  def __init__(self,
               func,
               transformation_name,
               dataset=None,
               input_classes=None,
               input_shapes=None,
               input_types=None,
               input_structure=None,
               add_to_graph=True,
               use_legacy_function=False,
               defun_kwargs=None):
    """Creates a new `StructuredFunctionWrapper` for the given function.

    Args:
      func: A function from a (nested) structure to another (nested) structure.
      transformation_name: Human-readable name of the transformation in which
        this function is being instantiated, for error messages.
      dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
        dataset will be assumed as the structure for `func` arguments; otherwise
        `input_classes`, `input_shapes`, and `input_types` must be defined.
      input_classes: (Optional.) A (nested) structure of `type`. If given, this
        argument defines the Python types for `func` arguments.
      input_shapes: (Optional.) A (nested) structure of `tf.TensorShape`. If
        given, this argument defines the shapes and structure for `func`
        arguments.
      input_types: (Optional.) A (nested) structure of `tf.DType`. If given,
        this argument defines the element types and structure for `func`
        arguments.
      input_structure: (Optional.) A `Structure` object. If given, this argument
        defines the element types and structure for `func` arguments.
      add_to_graph: (Optional.) If `True`, the function will be added to the
        default graph, if it exists.
      use_legacy_function: (Optional.) A boolean that determines whether the
        function is created using `tensorflow.python.eager.function.defun`
        (default behavior) or `tensorflow.python.framework.function.Defun`
        (legacy behavior).
      defun_kwargs: (Optional.) A dictionary mapping string argument names to
        values. If supplied, will be passed to `function` as keyword arguments.
        The dictionary is copied; the caller's object is never modified.

    Raises:
      ValueError: If an invalid combination of `dataset`, `input_classes`,
        `input_shapes`, and `input_types` is passed.
      TypeError: If the value returned by `func` cannot be converted to a type
        spec (raised by the tracing wrapper).
    """
    # pylint: disable=protected-access
    # Exactly one source of input structure is accepted: `input_structure`,
    # `dataset`, or the complete legacy (classes, shapes, types) triple.
    if input_structure is None:
      if dataset is None:
        if input_classes is None or input_shapes is None or input_types is None:
          raise ValueError("Either `dataset`, `input_structure` or all of "
                           "`input_classes`, `input_shapes`, and `input_types` "
                           "must be specified.")
        self._input_structure = structure.convert_legacy_structure(
            input_types, input_shapes, input_classes)
      else:
        if not (input_classes is None and input_shapes is None and
                input_types is None):
          raise ValueError("Either `dataset`, `input_structure` or all of "
                           "`input_classes`, `input_shapes`, and `input_types` "
                           "must be specified.")
        self._input_structure = dataset.element_spec
    else:
      if not (dataset is None and input_classes is None and
              input_shapes is None and input_types is None):
        raise ValueError("Either `dataset`, `input_structure`, or all of "
                         "`input_classes`, `input_shapes`, and `input_types` "
                         "must be specified.")
      self._input_structure = input_structure

    self._func = func

    # Work on a private copy: `func_name` and `_tf_data_function` are added
    # below and must not leak into a dictionary owned (and possibly reused) by
    # the caller.
    defun_kwargs = dict(defun_kwargs) if defun_kwargs is not None else {}

    # Replace dots and drop the last two characters of the transformation name
    # (presumably a trailing "()", e.g. "Dataset.map()" -> "Dataset_map").
    readable_transformation_name = transformation_name.replace(
        ".", "_")[:-2] if len(transformation_name) > 2 else ""

    func_name = "_".join(
        [readable_transformation_name,
         function_utils.get_func_name(func)])
    # Sanitize function name to remove symbols that interfere with graph
    # construction.
    for symbol in ["<", ">", "\\", "'", " "]:
      func_name = func_name.replace(symbol, "")

    # Capture the autograph status now so that `wrapper_helper` converts the
    # user function with the status that was active at construction time.
    ag_ctx = autograph_ctx.control_status_ctx()

    def wrapper_helper(*args):
      """Wrapper for passing nested structures to and from tf.data functions."""
      nested_args = structure.from_compatible_tensor_list(
          self._input_structure, args)
      if not _should_unpack(nested_args):
        nested_args = (nested_args,)
      ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
      ret = variable_utils.convert_variables_to_tensors(ret)
      if _should_pack(ret):
        ret = tuple(ret)

      # Record the output structure as a side effect; the properties of this
      # class read it after tracing.
      try:
        self._output_structure = structure.type_spec_from_value(ret)
      except (ValueError, TypeError) as e:
        raise TypeError(f"Unsupported return value from function passed to "
                        f"{transformation_name}: {ret}.") from e
      return ret

    def trace_legacy_function(defun_kwargs):
      """Returns a factory producing the function built with `function.Defun`."""

      @function.Defun(*structure.get_flat_tensor_types(self._input_structure),
                      **defun_kwargs)
      def wrapped_fn(*args):
        ret = wrapper_helper(*args)
        return structure.to_tensor_list(self._output_structure, ret)

      return lambda: wrapped_fn

    def trace_py_function(defun_kwargs):
      """Returns a factory for a function that runs `func` via `eager_py_func`."""

      # First we trace the function to infer the output structure.
      @eager_function.defun_with_attributes(
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure),
          autograph=False,
          attributes=defun_kwargs)
      def unused(*args):  # pylint: disable=missing-docstring,unused-variable
        ret = wrapper_helper(*args)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      _ = unused.get_concrete_function()

      def py_function_wrapper(*args):
        # Calls the user function directly (no autograph conversion) and
        # flattens the result using the structure inferred by the trace above.
        nested_args = structure.from_compatible_tensor_list(
            self._input_structure, args)
        if not _should_unpack(nested_args):
          nested_args = (nested_args,)
        ret = self._func(*nested_args)
        if _should_pack(ret):
          ret = tuple(ret)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      # Next we trace the function wrapped in `eager_py_func` to force eager
      # execution.
      @eager_function.defun_with_attributes(
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure),
          autograph=False,
          attributes=defun_kwargs)
      def wrapped_fn(*args):  # pylint: disable=missing-docstring
        return script_ops.eager_py_func(
            py_function_wrapper, args,
            structure.get_flat_tensor_types(self._output_structure))

      return wrapped_fn.get_concrete_function

    def trace_tf_function(defun_kwargs):
      """Returns a factory for the concrete function traced from `func`."""

      # Note: wrapper_helper will apply autograph based on context.
      @eager_function.defun_with_attributes(
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure),
          autograph=False,
          attributes=defun_kwargs)
      def wrapped_fn(*args):  # pylint: disable=missing-docstring
        ret = wrapper_helper(*args)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      return wrapped_fn.get_concrete_function

    if use_legacy_function:
      # Legacy functions get a uid suffix appended to their name.
      defun_kwargs.update({"func_name": func_name + "_" + str(ops.uid())})
      fn_factory = trace_legacy_function(defun_kwargs)
    else:
      defun_kwargs.update({"func_name": func_name})
      defun_kwargs.update({"_tf_data_function": True})
      if dataset_ops.DEBUG_MODE:
        fn_factory = trace_py_function(defun_kwargs)
      else:
        if def_function.functions_run_eagerly():
          warnings.warn(
              "Even though the `tf.config.experimental_run_functions_eagerly` "
              "option is set, this option does not apply to tf.data functions. "
              "To force eager execution of tf.data functions, please use "
              "`tf.data.experimental.enable_debug_mode()`.")
        fn_factory = trace_tf_function(defun_kwargs)

    self._function = fn_factory()
    # There is no graph to add in eager mode.
    add_to_graph &= not context.executing_eagerly()
    # There are some lifetime issues when a legacy function is not added to a
    # out-living graph. It's already deprecated so de-prioritizing the fix.
    add_to_graph |= use_legacy_function
    if add_to_graph:
      self._function.add_to_graph(ops.get_default_graph())

    if not use_legacy_function:
      # Warn when the traced function's graph carries the same seed as the
      # outer graph and that seed has been marked as used.
      outer_graph_seed = ops.get_default_graph().seed
      if outer_graph_seed and self._function.graph.seed == outer_graph_seed:
        if self._function.graph._seed_used:
          warnings.warn(
              "Seed %s from outer graph might be getting used by function %s, "
              "if the random op has not been provided any seed. Explicitly set "
              "the seed in the function if this is not the intended behavior." %
              (outer_graph_seed, func_name),
              stacklevel=4)

  @property
  def output_structure(self):
    """The structure recorded from the wrapped function's return value."""
    return self._output_structure

  @property
  def output_classes(self):
    """Legacy output classes derived from each component of the structure."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def output_shapes(self):
    """Legacy output shapes derived from each component of the structure."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def output_types(self):
    """Legacy output types derived from each component of the structure."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def function(self):
    """The traced function object produced for the wrapped `func`."""
    return self._function