xref: /aosp_15_r20/external/tensorflow/tensorflow/python/eager/execute.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functions called by the generated code to execute an eager-mode op."""
16
17from google.protobuf import text_format
18from tensorflow.core.framework import tensor_pb2
19from tensorflow.python import pywrap_tfe
20from tensorflow.python.eager import core
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.util import compat
25
26
def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Runs one TensorFlow op eagerly through `pywrap_tfe.TFE_Py_Execute`.

  Args:
    op_name: Registered name of the op to run (see REGISTER_OP in the C++
      code).
    num_outputs: Number of outputs to fetch. The caller passes it explicitly
      instead of having it inferred, for performance.
    inputs: List of op inputs. Each is a Tensor or a value the Tensor
      constructor can turn into one.
    attrs: Flat tuple that alternates attr names (strings) and attr values.
    ctx: The eager context, i.e. the value of `context.context()`.
    name: Optional custom op name. When set, it is appended to the message of
      a non-OK status before that status is converted to an exception.

  Returns:
    The list of output Tensors; empty when the op has no outputs.

  Raises:
    The exception produced by `core._status_to_exception` when execution
    reports a non-OK status; `core._SymbolicException` when a TypeError is
    raised and at least one input is a Keras symbolic tensor; otherwise the
    original TypeError is re-raised.
  """
  device = ctx.device_name
  # pylint: disable=protected-access
  try:
    # `ctx._handle` is only read after `ensure_initialized()` has run.
    ctx.ensure_initialized()
    outputs = pywrap_tfe.TFE_Py_Execute(ctx._handle, device, op_name, inputs,
                                        attrs, num_outputs)
  except core._NotOkStatusException as status_error:
    if name is not None:
      status_error.message += " name: " + name
    # `from None` keeps the raw status exception out of the traceback.
    raise core._status_to_exception(status_error) from None
  except TypeError as type_error:
    # If Keras symbolic tensors are among the inputs, report them with a
    # dedicated error; otherwise propagate the TypeError unchanged.
    symbolic_inputs = list(filter(ops._is_keras_symbolic_tensor, inputs))
    if not symbolic_inputs:
      raise type_error
    raise core._SymbolicException(
        "Inputs to eager execution function cannot be Keras symbolic "
        "tensors, but found {}".format(symbolic_inputs))
  # pylint: enable=protected-access
  return outputs
69
70
def execute_with_cancellation(op_name,
                              num_outputs,
                              inputs,
                              attrs,
                              ctx,
                              cancellation_manager,
                              name=None):
  """Runs one TensorFlow op eagerly via `TFE_Py_ExecuteCancelable`.

  Same as `quick_execute`, except that the op is tied to a cancellation
  manager so that it can be cancelled.

  Args:
    op_name: Registered name of the op to run (see REGISTER_OP in the C++
      code).
    num_outputs: Number of outputs to fetch. The caller passes it explicitly
      instead of having it inferred, for performance.
    inputs: List of op inputs. Each is a Tensor or a value the Tensor
      constructor can turn into one.
    attrs: Flat tuple that alternates attr names (strings) and attr values.
    ctx: The eager context, i.e. the value of `context.context()`.
    cancellation_manager: A `CancellationManager`. Its `_impl` is handed to
      the C++ executor and can be used to cancel the op.
    name: Optional custom op name. When set, it is appended to the message of
      a non-OK status before that status is converted to an exception.

  Returns:
    The list of output Tensors; empty when the op has no outputs.

  Raises:
    The exception produced by `core._status_to_exception` when execution
    reports a non-OK status; `core._SymbolicException` when a TypeError is
    raised and at least one input is a Keras symbolic tensor; otherwise the
    original TypeError is re-raised.
  """
  device = ctx.device_name
  # pylint: disable=protected-access
  try:
    ctx.ensure_initialized()
    outputs = pywrap_tfe.TFE_Py_ExecuteCancelable(
        ctx._handle, device, op_name, inputs, attrs,
        cancellation_manager._impl, num_outputs)
  except core._NotOkStatusException as status_error:
    if name is not None:
      status_error.message += " name: " + name
    # `from None` keeps the raw status exception out of the traceback.
    raise core._status_to_exception(status_error) from None
  except TypeError as type_error:
    # If Keras symbolic tensors are among the inputs, report them with a
    # dedicated error; otherwise propagate the TypeError unchanged.
    symbolic_inputs = list(filter(ops._is_keras_symbolic_tensor, inputs))
    if not symbolic_inputs:
      raise type_error
    raise core._SymbolicException(
        "Inputs to eager execution function cannot be Keras symbolic "
        "tensors, but found {}".format(symbolic_inputs))
  # pylint: enable=protected-access
  return outputs
123
124
def execute_with_callbacks(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Runs an op with `quick_execute`, then notifies every op callback.

  Meant to be monkey-patched in place of `execute` when execution callbacks
  are enabled. Each callable in `ctx.op_callbacks` is invoked in order as
  `callback(op_name, tuple(inputs), attrs, outputs, name)`. Callback return
  values are ignored.

  Returns:
    Whatever `quick_execute` returned for the op.
  """
  outputs = quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
  for op_callback in ctx.op_callbacks:
    op_callback(op_name, tuple(inputs), attrs, outputs, name)
  return outputs
132
133
# Entry point for running an eager op; by default it is `quick_execute`.
# `execute_with_callbacks` describes itself as a monkey-patch for `execute`,
# so this name is presumably rebound to it when op callbacks are enabled --
# confirm at the code that performs the rebinding.
execute = quick_execute
135
136
def must_record_gradient():
  """Reports whether gradients must be recorded; this stub always says no.

  Import backprop if you want gradients recorded.
  """
  recording_enabled = False
  return recording_enabled
140
141
def record_gradient(unused_op_name, unused_inputs, unused_attrs,
                    unused_outputs):
  """No-op placeholder for gradient recording; every argument is ignored.

  Import backprop if you want gradients recorded.
  """
  del unused_op_name, unused_inputs, unused_attrs, unused_outputs  # Unused.
146
147
def make_float(v, arg_name):
  """Returns `v` as a Python float after checking it is a real number.

  Args:
    v: Value to convert; must be an instance of `compat.real_types`.
    arg_name: Argument name, used only in the error message.

  Returns:
    `float(v)`.

  Raises:
    TypeError: if `v` is not an instance of `compat.real_types`.
  """
  if isinstance(v, compat.real_types):
    return float(v)
  raise TypeError("Expected float for argument '%s' not %s." %
                  (arg_name, repr(v)))
153
154
def make_int(v, arg_name):
  """Converts `v` to a Python int for the argument named `arg_name`.

  Text and binary strings are rejected outright, because `int()` would
  otherwise parse values such as "3" or b"3" as numbers. Any other value is
  passed to `int()`, so floats, for example, are truncated toward zero.

  Args:
    v: The value to convert.
    arg_name: Argument name, used only in the error message.

  Returns:
    `int(v)`.

  Raises:
    TypeError: if `v` is a `str`, `bytes` or `bytearray`, or if `int(v)`
      fails with a ValueError or TypeError.
  """
  # Python 3's int() parses bytes and bytearray as well as str
  # (int(b"3") == 3). All three must be rejected; checking only `str` lets
  # binary strings slip through as integers.
  if isinstance(v, (str, bytes, bytearray)):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  try:
    return int(v)
  except (ValueError, TypeError):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
164
165
def make_str(v, arg_name):
  """Returns `v` as bytes after checking it is a text or byte string.

  Args:
    v: Value to convert; must be an instance of
      `compat.bytes_or_text_types`.
    arg_name: Argument name, used only in the error message.

  Returns:
    `compat.as_bytes(v)`.

  Raises:
    TypeError: if `v` is not a text or byte string.
  """
  if isinstance(v, compat.bytes_or_text_types):
    # Unicode strings are converted to bytes by compat.as_bytes.
    return compat.as_bytes(v)
  raise TypeError("Expected string for argument '%s' not %s." %
                  (arg_name, repr(v)))
171
172
def make_bool(v, arg_name):
  """Returns `v` unchanged if it is a bool; no truthiness coercion is done.

  Args:
    v: Value to check.
    arg_name: Argument name, used only in the error message.

  Returns:
    `v` itself.

  Raises:
    TypeError: if `v` is not an instance of `bool` (e.g. the int 1).
  """
  if isinstance(v, bool):
    return v
  raise TypeError("Expected bool for argument '%s' not %s." %
                  (arg_name, repr(v)))
178
179
def make_type(v, arg_name):
  """Returns the datatype enum of the base dtype of `v`.

  Args:
    v: Anything `dtypes.as_dtype` accepts.
    arg_name: Argument name, used only in the error message.

  Returns:
    `dtypes.as_dtype(v).base_dtype.as_datatype_enum`.

  Raises:
    TypeError: if `dtypes.as_dtype(v)` (or `.base_dtype`) raises TypeError.
  """
  try:
    base = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return base.as_datatype_enum
188
189
def make_shape(v, arg_name):
  """Converts `v` to a list of dimensions, or None when the rank is unknown.

  Args:
    v: A TensorShapeProto, a list of ints, or a `tensor_shape.TensorShape`
      (anything `tensor_shape.as_shape` accepts).
    arg_name: Argument name, used only in error messages.

  Returns:
    None if the rank is unknown. Otherwise a list with one entry per
    dimension: an int, or None where that dimension is unknown.

  Raises:
    TypeError: if `tensor_shape.as_shape` raises TypeError.
    ValueError: if `tensor_shape.as_shape` raises ValueError.
  """
  try:
    shape = tensor_shape.as_shape(v)
  except TypeError as err:
    raise TypeError(
        "Error converting %s to a TensorShape: %s." % (arg_name, err))
  except ValueError as err:
    raise ValueError(
        "Error converting %s to a TensorShape: %s." % (arg_name, err))
  return None if shape.ndims is None else shape.as_list()
210
211
def make_tensor(v, arg_name):
  """Returns `v` as a TensorProto.

  A TensorProto is returned as-is. A `str` is parsed as a text-format
  TensorProto into a new message.

  Args:
    v: A `tensor_pb2.TensorProto`, or a `str` holding its text format.
    arg_name: Argument name, used only in the error message.

  Returns:
    A `tensor_pb2.TensorProto`.

  Raises:
    TypeError: if `v` is neither a TensorProto nor a `str`.
  """
  if isinstance(v, tensor_pb2.TensorProto):
    return v
  if not isinstance(v, str):
    raise TypeError(
        "Don't know how to convert %s to a TensorProto for argument '%s'." %
        (repr(v), arg_name))
  proto = tensor_pb2.TensorProto()
  text_format.Merge(v, proto)
  return proto
223
224
def args_to_matching_eager(l, ctx, allowed_dtypes, default_dtype=None):
  """Convert sequence `l` to eager same-type Tensors.

  Args:
    l: Sequence of values (EagerTensors or values convertible by
      `ops.convert_to_tensor`) that must all end up with one common dtype.
    ctx: Eager context forwarded to `ops.convert_to_tensor`.
    allowed_dtypes: Collection of dtypes the op accepts. It is consulted only
      when no element of `l` is already an EagerTensor: the first value is
      then converted with no dtype hint, and that result is kept only if its
      dtype is in this collection. May be empty (e.g. for ConcatV2), in
      which case the check is skipped.
    default_dtype: Returned unchanged as the dtype when `l` is empty. It is
      also passed as `preferred_dtype` when inferring the first value's dtype.

  Returns:
    A `(dtype, tensors)` pair. `dtype` is the datatype enum shared by all the
    tensors, except on the empty-`l` path, where `default_dtype` itself is
    returned. `tensors` is `l` itself when every element is already an
    EagerTensor; otherwise it is a new list of converted tensors.

  Raises:
    core._SymbolicException: if any converted value is a Keras symbolic
      tensor. (This check is skipped on the all-EagerTensor fast path.)
  """
  if (not l) and (default_dtype is not None):
    return default_dtype, []  # List is empty; assume default dtype.
  EagerTensor = ops.EagerTensor  # pylint: disable=invalid-name
  # Fast path: if every element is already an EagerTensor, return `l`
  # unchanged together with the first element's dtype enum.
  # NOTE(review): if `l` is empty and `default_dtype` is None, the loop body
  # never runs, so the `else` clause evaluates `l[0]` and raises IndexError.
  for x in l:
    if not isinstance(x, EagerTensor):
      break
  else:  # note: intentional for-else
    return l[0]._datatype_enum(), l  # pylint: disable=protected-access

  # Is some input already a Tensor with a dtype? If so, the first such
  # EagerTensor's dtype is used for every element.
  dtype = None
  for t in l:
    if isinstance(t, EagerTensor):
      dtype = t.dtype
      break

  if dtype is None:
    # Infer a dtype based on the first value, and use that dtype for the
    # remaining values.

    ret = []
    for t in l:
      tensor = None
      # First see if the default conversion yields a dtype that is one of the
      # allowed dtypes. Some ops like ConcatV2 may not list allowed dtypes,
      # in which case we should skip this.
      if dtype is None and allowed_dtypes:
        tensor = ops.convert_to_tensor(t, ctx=ctx)
        # If we did not match an allowed dtype, try again with the default
        # dtype. This could be because we have an empty tensor and thus we
        # picked the wrong type.
        if tensor.dtype not in allowed_dtypes:
          tensor = None

      if tensor is None:
        tensor = ops.convert_to_tensor(
            t, dtype, preferred_dtype=default_dtype, ctx=ctx)

      ret.append(tensor)
      # Only the first value determines the dtype; later values are
      # converted to it.
      if dtype is None:
        dtype = tensor.dtype
  else:
    ret = [ops.convert_to_tensor(t, dtype, ctx=ctx) for t in l]

  # TODO(slebedev): consider removing this as it leaks a Keras concept.
  # pylint: disable=protected-access
  keras_symbolic_tensors = [x for x in ret if ops._is_keras_symbolic_tensor(x)]
  if keras_symbolic_tensors:
    raise core._SymbolicException(
        "Using symbolic output of a Keras layer during eager execution "
        "{}".format(keras_symbolic_tensors))
  # pylint: enable=protected-access
  return dtype.as_datatype_enum, ret
280
281
def convert_to_mixed_eager_tensors(values, ctx):
  """Converts each value to a tensor independently; dtypes may differ.

  Args:
    values: Iterable of values to pass to `ops.convert_to_tensor`.
    ctx: Eager context forwarded to the conversion.

  Returns:
    A `(types, tensors)` pair: the datatype enum of each converted tensor,
    and the converted tensors themselves, both in input order.
  """
  converted = [ops.convert_to_tensor(value, ctx=ctx) for value in values]
  # pylint: disable=protected-access
  enums = [tensor._datatype_enum() for tensor in converted]
  # pylint: enable=protected-access
  return enums, converted
286
287
def args_to_mixed_eager_tensors(lists, ctx):
  """Converts parallel, same-length lists of values to eager tensors.

  The lists are processed position by position. For each index `i`, the
  i-th elements of all lists get one common dtype. That dtype is taken from
  the first of them that is already an EagerTensor. If none is, it is the
  dtype produced by converting `lists[0][i]` with no dtype hint. Different
  positions may end up with different dtypes.

  Args:
    lists: Sequence of at least two sequences, all of the same length.
    ctx: Eager context forwarded to `ops.convert_to_tensor`.

  Returns:
    A `(types, lists_ret)` pair. `types[i]` is the datatype enum chosen for
    position `i`, and `lists_ret[j][i]` is the converted `lists[j][i]`.

  Raises:
    ValueError: if the lists do not all have the same length.
  """
  assert len(lists) > 1

  head = lists[0]
  # Reject ragged input before converting anything.
  for other in lists[1:]:
    if len(other) != len(head):
      raise ValueError(
          "Expected list arguments to be the same length: %d != %d (%r vs. %r)."
          % (len(head), len(other), head, other))

  converted = [[] for _ in lists]
  types = []
  # Walk the lists in lockstep: `column` holds the i-th element of each list.
  for column in zip(*lists):
    dtype = next(
        (value.dtype for value in column
         if isinstance(value, ops.EagerTensor)), None)
    start = 0
    if dtype is None:
      # No EagerTensor at this position, so the first list's element picks
      # the dtype and the remaining elements are coerced to it.
      first = ops.convert_to_tensor(column[0], ctx=ctx)
      converted[0].append(first)
      dtype = first.dtype
      start = 1
    for target, value in zip(converted[start:], column[start:]):
      target.append(ops.convert_to_tensor(value, dtype=dtype, ctx=ctx))
    types.append(dtype.as_datatype_enum)
  return types, converted
324