xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/utils/py_func.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Pyfunc creation utilities."""
16
17from collections import namedtuple
18
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import tensor_util
21from tensorflow.python.ops import script_ops
22
23
class MatchDType(namedtuple('MatchDType', ['arg_number'])):
  """Placeholder dtype that resolves to the dtype of a positional argument.

  Meant to be passed (alone or inside a list/tuple) as the `return_dtypes`
  argument of `wrap_py_func`: `MatchDType(i)` stands for "the same DType as
  positional argument `i`", counting from 0. For instance, `MatchDType(0)`
  resolves to the DType of the first argument.
  """
32
33
def wrap_py_func(f, return_dtypes, args, kwargs=None, use_dummy_return=False):
  """Helper that wraps a callable to py_func.

  The helper passes tensor arguments through the py_func interface. Non-tensor
  arguments are allowed, and will be passed to f directly. Note that non-tensor
  arguments are captured by the wrapper's closure and are not re-read from the
  py_func inputs each time the wrapper is called (this is consistent with its
  argument list, which only includes the tensor arguments). In general, it's
  safest not to reuse this wrapper.

  Args:
    f: Callable
    return_dtypes: None, individual or tuple/list of DType or MatchDType, the
        data type for each of f's return value(s). Set to None if f has no
        return values or use_dummy_return is True. Use MatchDType to define a
        dtype identical to that of `i`th argument (argument 0 is the first);
        an argument must be of Tensor type if it is to be used with MatchDType.
    args: Positional arguments for f, as list or tuple.
    kwargs: Keyword arguments for f, as dict with string keys. May be None.
    use_dummy_return: If True, the function will return a dummy value of 1
        and discard its actual return value.
  Returns:
    The return values of f converted to tensor, or a dummy int32 value of 1
    when use_dummy_return is True.
  Raises:
    ValueError: if any of the arguments are incorrect, e.g. return_dtypes is
      given together with use_dummy_return, a MatchDType refers to a missing
      or non-tensor argument, or return_dtypes has an unsupported type.
  """

  if return_dtypes and use_dummy_return:
    raise ValueError('if use_dummy_return is True, return_dtypes must be empty')

  tensor_args = []
  # Maps a positional index (int) or keyword name (str) to the position of the
  # corresponding tensor in `tensor_args`. Only tensor arguments get an entry.
  tensor_args_idx = {}

  # Of the positional arguments, only grab the tensor ones to be passed through
  # the py_func.
  arg_is_tensor = tuple(map(tensor_util.is_tf_type, args))
  for i, (arg, is_tensor) in enumerate(zip(args, arg_is_tensor)):
    if is_tensor:
      tensor_args_idx[i] = len(tensor_args)
      tensor_args.append(arg)

  # We essentially take the tensor kwargs, if any, and add them to the list of
  # positional arguments. The kwargs are then reconstructed inside the py_func.
  #
  # For example, if
  #
  #     args = [Tensor(1), 'foo']
  #     kwargs = {'a': Tensor(2), 'b': 'bar'}
  #
  # Then
  #
  #     tensor_args = (Tensor(1), Tensor(2))
  #     tensor_args_idx = {0: 0, 'a': 1}
  #     kwarg_keys = ('a', 'b')
  if kwargs:
    kwarg_keys = tuple(kwargs.keys())
    kwarg_is_tensor = {k: tensor_util.is_tf_type(kwargs[k]) for k in kwarg_keys}
    for k in kwarg_keys:
      if kwarg_is_tensor[k]:
        tensor_args_idx[k] = len(tensor_args)
        tensor_args.append(kwargs[k])
  else:
    kwarg_keys = ()
    # Bound in both branches so f_wrapper never closes over an unbound name.
    kwarg_is_tensor = {}

  # Set up return dtypes.
  def match_arg_dtype(arg_number):
    """Returns the dtype of positional argument `arg_number` (a tensor)."""
    try:
      arg = args[arg_number]
    except IndexError as e:
      # Surface a bad MatchDType index as the documented ValueError.
      raise ValueError(
          'argument %d was used with MatchDType, but only %d positional '
          'arguments were given' % (arg_number, len(args))) from e
    if not arg_is_tensor[arg_number]:
      raise ValueError(
          'argument %d was used with MatchDType and must be a tf.Tensor, but '
          'was %s instead' % (arg_number, type(arg)))
    return arg.dtype

  if return_dtypes:
    # MatchDType is itself a tuple subclass, so it must be tested before the
    # generic list/tuple branch.
    if isinstance(return_dtypes, MatchDType):
      return_dtypes = match_arg_dtype(return_dtypes.arg_number)
    elif isinstance(return_dtypes, (list, tuple)):
      return_dtypes = tuple(
          match_arg_dtype(a.arg_number) if isinstance(a, MatchDType) else a
          for a in return_dtypes)
    elif not isinstance(return_dtypes, dtypes.DType):
      # Raised explicitly (not asserted) so the check survives `python -O`.
      raise ValueError(
          'return_dtypes must be a DType, a MatchDType, or a list/tuple of '
          'those, but was %s' % type(return_dtypes))

  def f_wrapper(*py_func_args):
    """Rebuilds f's full argument list from the values py_func passes in."""
    f_args = tuple(
        py_func_args[tensor_args_idx[i]] if arg_is_tensor[i] else a
        for i, a in enumerate(args))
    f_kwargs = {
        k: py_func_args[tensor_args_idx[k]] if kwarg_is_tensor[k] else kwargs[k]
        for k in kwarg_keys
    }
    retval = f(*f_args, **f_kwargs)
    # f still runs for its side effects; only its return value is replaced.
    return 1 if use_dummy_return else retval

  if use_dummy_return:
    return_dtypes = dtypes.int32
  return script_ops.eager_py_func(f_wrapper, tensor_args, return_dtypes)
129