xref: /aosp_15_r20/external/tensorflow/tensorflow/python/data/ops/structured_function.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for managing tf.data user-defined functions."""
16
17import warnings
18
19from tensorflow.python.data.util import nest
20from tensorflow.python.data.util import structure
21from tensorflow.python.eager import context
22from tensorflow.python.eager import def_function
23from tensorflow.python.eager import function as eager_function
24
25from tensorflow.python.framework import function
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import script_ops
28from tensorflow.python.util import function_utils
29from tensorflow.python.util import lazy_loader
30from tensorflow.python.util import variable_utils
31
32autograph = lazy_loader.LazyLoader(
33    "autograph", globals(),
34    "tensorflow.python.autograph.impl.api")
35# TODO(mdan): Create a public API for this.
36autograph_ctx = lazy_loader.LazyLoader(
37    "autograph_ctx", globals(),
38    "tensorflow.python.autograph.core.ag_ctx")
39dataset_ops = lazy_loader.LazyLoader(
40    "dataset_ops", globals(),
41    "tensorflow.python.data.ops.dataset_ops")
42
43
def _should_pack(arg):
  """Reports whether a user-function return value must be packed into a tuple.

  The tf.data flavor of `nest.flatten()` does not recurse into lists. A
  returned `list` of tensors would therefore reach `ops.convert_to_tensor()`
  whole, which would try to stack it into a single tensor.

  Such a list most often comes from an op like `tf.numpy_function()`, whose
  output tensors need not be stackable. For that reason a returned list is
  treated as a `tuple` instead. Users who do want one stacked tensor should
  call `tf.stack()` explicitly before returning.

  Args:
    arg: The value to inspect (typically what a user-defined function
      returned).

  Returns:
    `True` if `arg` is a `list` (list subclasses included) and should be
    converted to a `tuple`; `False` otherwise.
  """
  if isinstance(arg, list):
    return True
  return False
63
64
def _should_unpack(arg):
  """Reports whether `arg` should be spread into separate positional arguments.

  Args:
    arg: The value to inspect.

  Returns:
    `True` only when `arg` is exactly a built-in `tuple`. Tuple subclasses
    such as namedtuples yield `False`, as does every other type.
  """
  # Exact type comparison on purpose: `isinstance` would also accept
  # namedtuples and other tuple subclasses.
  arg_type = type(arg)
  return arg_type is tuple  # pylint: disable=unidiomatic-typecheck
75
76
class StructuredFunctionWrapper():
  """A function wrapper that supports structured arguments and return values.

  `func` is traced into a function whose flat tensor inputs are derived from
  the input structure. While `func` is traced, the structure of the value it
  returns is recorded. That structure is exposed through `output_structure`
  and the legacy `output_classes`, `output_shapes`, and `output_types`
  properties. The traced function itself is available as `function`.
  """

  def __init__(self,
               func,
               transformation_name,
               dataset=None,
               input_classes=None,
               input_shapes=None,
               input_types=None,
               input_structure=None,
               add_to_graph=True,
               use_legacy_function=False,
               defun_kwargs=None):
    """Creates a new `StructuredFunctionWrapper` for the given function.

    Args:
      func: A function from a (nested) structure to another (nested) structure.
      transformation_name: Human-readable name of the transformation in which
        this function is being instantiated, for error messages.
      dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
        dataset will be assumed as the structure for `func` arguments; otherwise
        `input_classes`, `input_shapes`, and `input_types` must be defined.
      input_classes: (Optional.) A (nested) structure of `type`. If given, this
        argument defines the Python types for `func` arguments.
      input_shapes: (Optional.) A (nested) structure of `tf.TensorShape`. If
        given, this argument defines the shapes and structure for `func`
        arguments.
      input_types: (Optional.) A (nested) structure of `tf.DType`. If given,
        this argument defines the element types and structure for `func`
        arguments.
      input_structure: (Optional.) A `Structure` object. If given, this argument
        defines the element types and structure for `func` arguments.
      add_to_graph: (Optional.) If `True`, the function will be added to the
        default graph, if it exists.
      use_legacy_function: (Optional.) A boolean that determines whether the
        function is created using `tensorflow.python.eager.function.defun`
        (default behavior) or `tensorflow.python.framework.function.Defun`
        (legacy behavior).
      defun_kwargs: (Optional.) A dictionary mapping string argument names to
        values. If supplied, will be passed to `function` as keyword arguments.
        The dictionary is copied; the caller's object is never modified.

    Raises:
      ValueError: If an invalid combination of `dataset`, `input_classes`,
        `input_shapes`, and `input_types` is passed.
    """
    # pylint: disable=protected-access
    self._input_structure = self._resolve_input_structure(
        dataset, input_classes, input_shapes, input_types, input_structure)
    self._func = func

    # Work on a private copy. The "func_name" / "_tf_data_function" entries
    # added below must not leak into a dict the caller owns and may reuse.
    defun_kwargs = {} if defun_kwargs is None else dict(defun_kwargs)

    func_name = self._sanitized_func_name(transformation_name, func)

    # Captured once, so every trace of `func` below applies the autograph
    # status that was in effect when this wrapper was constructed.
    ag_ctx = autograph_ctx.control_status_ctx()

    def wrapper_helper(*args):
      """Wrapper for passing nested structures to and from tf.data functions."""
      nested_args = structure.from_compatible_tensor_list(
          self._input_structure, args)
      if not _should_unpack(nested_args):
        nested_args = (nested_args,)
      ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
      ret = variable_utils.convert_variables_to_tensors(ret)
      if _should_pack(ret):
        ret = tuple(ret)

      try:
        self._output_structure = structure.type_spec_from_value(ret)
      except (ValueError, TypeError) as e:
        raise TypeError(f"Unsupported return value from function passed to "
                        f"{transformation_name}: {ret}.") from e
      return ret

    def trace_legacy_function(defun_kwargs):

      @function.Defun(*structure.get_flat_tensor_types(self._input_structure),
                      **defun_kwargs)
      def wrapped_fn(*args):
        ret = wrapper_helper(*args)
        return structure.to_tensor_list(self._output_structure, ret)

      # Return a zero-argument factory, like the other trace_* helpers, so all
      # paths share the single `fn_factory()` call below.
      return lambda: wrapped_fn

    def trace_py_function(defun_kwargs):
      # First we trace the function to infer the output structure.
      @eager_function.defun_with_attributes(
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure),
          autograph=False,
          attributes=defun_kwargs)
      def unused(*args):  # pylint: disable=missing-docstring,unused-variable
        ret = wrapper_helper(*args)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      _ = unused.get_concrete_function()

      def py_function_wrapper(*args):
        nested_args = structure.from_compatible_tensor_list(
            self._input_structure, args)
        if not _should_unpack(nested_args):
          nested_args = (nested_args,)
        ret = self._func(*nested_args)
        # Mirror `wrapper_helper`: `self._output_structure` was inferred after
        # variables had been converted to tensors, so do the same here.
        ret = variable_utils.convert_variables_to_tensors(ret)
        if _should_pack(ret):
          ret = tuple(ret)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      # Next we trace the function wrapped in `eager_py_func` to force eager
      # execution.
      @eager_function.defun_with_attributes(
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure),
          autograph=False,
          attributes=defun_kwargs)
      def wrapped_fn(*args):  # pylint: disable=missing-docstring
        return script_ops.eager_py_func(
            py_function_wrapper, args,
            structure.get_flat_tensor_types(self._output_structure))

      return wrapped_fn.get_concrete_function

    def trace_tf_function(defun_kwargs):
      # Note: wrapper_helper will apply autograph based on context.
      @eager_function.defun_with_attributes(
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure),
          autograph=False,
          attributes=defun_kwargs)
      def wrapped_fn(*args):  # pylint: disable=missing-docstring
        ret = wrapper_helper(*args)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      return wrapped_fn.get_concrete_function

    if use_legacy_function:
      # Suffix with `ops.uid()`, presumably so repeated legacy traces get
      # distinct function names.
      defun_kwargs["func_name"] = func_name + "_" + str(ops.uid())
      fn_factory = trace_legacy_function(defun_kwargs)
    else:
      defun_kwargs["func_name"] = func_name
      defun_kwargs["_tf_data_function"] = True
      if dataset_ops.DEBUG_MODE:
        fn_factory = trace_py_function(defun_kwargs)
      else:
        if def_function.functions_run_eagerly():
          warnings.warn(
              "Even though the `tf.config.experimental_run_functions_eagerly` "
              "option is set, this option does not apply to tf.data functions. "
              "To force eager execution of tf.data functions, please use "
              "`tf.data.experimental.enable_debug_mode()`.")
        fn_factory = trace_tf_function(defun_kwargs)

    self._function = fn_factory()
    # There is no graph to add in eager mode.
    add_to_graph &= not context.executing_eagerly()
    # There are some lifetime issues when a legacy function is not added to a
    # out-living graph. It's already deprecated so de-prioritizing the fix.
    add_to_graph |= use_legacy_function
    if add_to_graph:
      self._function.add_to_graph(ops.get_default_graph())

    if not use_legacy_function:
      outer_graph_seed = ops.get_default_graph().seed
      if outer_graph_seed and self._function.graph.seed == outer_graph_seed:
        if self._function.graph._seed_used:
          warnings.warn(
              "Seed %s from outer graph might be getting used by function %s, "
              "if the random op has not been provided any seed. Explicitly set "
              "the seed in the function if this is not the intended behavior." %
              (outer_graph_seed, func_name),
              stacklevel=4)

  @staticmethod
  def _resolve_input_structure(dataset, input_classes, input_shapes,
                               input_types, input_structure):
    """Returns the input structure described by exactly one input source.

    The sources are checked in this order of precedence:
    1. An explicit `input_structure`.
    2. `dataset.element_spec`.
    3. The legacy triple (`input_types`, `input_shapes`, `input_classes`).

    Raises:
      ValueError: If the chosen source is combined with another one, or if no
        complete source is given.
    """
    if input_structure is not None:
      if not (dataset is None and input_classes is None and
              input_shapes is None and input_types is None):
        raise ValueError("Either `dataset`, `input_structure`, or all of "
                         "`input_classes`, `input_shapes`, and `input_types` "
                         "must be specified.")
      return input_structure
    if dataset is not None:
      if not (input_classes is None and input_shapes is None and
              input_types is None):
        raise ValueError("Either `dataset`, `input_structure` or all of "
                         "`input_classes`, `input_shapes`, and `input_types` "
                         "must be specified.")
      return dataset.element_spec
    if input_classes is None or input_shapes is None or input_types is None:
      raise ValueError("Either `dataset`, `input_structure` or all of "
                       "`input_classes`, `input_shapes`, and `input_types` "
                       "must be specified.")
    return structure.convert_legacy_structure(
        input_types, input_shapes, input_classes)

  @staticmethod
  def _sanitized_func_name(transformation_name, func):
    """Builds a function name from the transformation name and `func`'s name.

    Args:
      transformation_name: Human-readable transformation name.
      func: The user function whose name is appended.

    Returns:
      A name with the characters `<`, `>`, `\\`, `'` and space removed.
    """
    # Dots become underscores and the last two characters are dropped.
    # Presumably this strips a trailing "()" as in "Dataset.map()"; confirm
    # with callers. Names of two characters or fewer are dropped entirely.
    readable_transformation_name = transformation_name.replace(
        ".", "_")[:-2] if len(transformation_name) > 2 else ""

    func_name = "_".join(
        [readable_transformation_name,
         function_utils.get_func_name(func)])
    # Sanitize function name to remove symbols that interfere with graph
    # construction.
    for symbol in ["<", ">", "\\", "'", " "]:
      func_name = func_name.replace(symbol, "")
    return func_name

  @property
  def output_structure(self):
    """The structure of `func`'s return value, recorded while tracing it."""
    return self._output_structure

  @property
  def output_classes(self):
    """Legacy per-component Python classes derived from `output_structure`."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def output_shapes(self):
    """Legacy per-component shapes derived from `output_structure`."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def output_types(self):
    """Legacy per-component dtypes derived from `output_structure`."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def function(self):
    """The traced function built in `__init__`."""
    return self._function
308