# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script Language Operators."""

# pylint: disable=g-bad-name
import threading

# Used by py_util.cc to get tracebacks.
import traceback  # pylint: disable=unused-import
import weakref

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.lib.core import _pywrap_py_func
from tensorflow.python.ops import gen_script_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export

autograph = lazy_loader.LazyLoader(
    "autograph", globals(),
    "tensorflow.python.autograph.impl.api")


# Map from EagerPyFunc token to tuple (tape, eager args, eager outputs);
# used for differentiation.
tape_cache = {}


def _maybe_copy_to_context_device(tensor, device_name):
  """Copy an EagerTensor to the current device if it's not on `device_name`."""
  in_device = tensor.backing_device
  if device_name == in_device:
    return tensor
  else:
    # Note that EagerTensor._copy bypasses the placer and copies to the context
    # device, which means e.g. int32 Tensors which would normally be forced onto
    # the CPU can instead be placed on the GPU. This is necessary so that the
    # PyFunc kernel always returns Tensors on the device it's executing on.
    return tensor._copy()  # pylint: disable=protected-access


class EagerFunc:
  """A wrapper for a function owned by an EagerPyFunc."""

  def __init__(self, func, Tout, is_grad_func):
    """Constructs an EagerFunc.

    Args:
      func: The function to wrap.
      Tout: A list of datatypes for the output; an empty list if the output is
        None.
      is_grad_func: Whether this EagerFunc is the gradient of another
        EagerPyFunc.
    """
    self._func = func
    self._out_dtypes = Tout
    self._is_grad_func = is_grad_func
    self._support_graph_mode_gradient = False

  def set_support_graph_mode_gradient(self):
    """Indicates the object shall support gradient ops.

    This function is internally used by _EagerPyFuncGrad to support
    graph mode gradients of EagerFunc via tf.gradients().
    """
    self._support_graph_mode_gradient = True

  def _convert(self, value, dtype):
    """Converts `value` to a tensor of type `dtype`, with error checking.

    Args:
      value: The tensor to convert.
      dtype: The desired dtype.

    Returns:
      A tensor of type `dtype`, or a zeros tensor if value is None and
      this function is in fact a gradient function.

    Raises:
      RuntimeError: if `value` is a variable.
    """

    if isinstance(value, resource_variable_ops.ResourceVariable):
      raise RuntimeError(
          "Attempting to return a variable from an eagerly executed py_func. "
          "Only numeric data structures like Tensors or NumPy arrays should "
          "be returned; to return the value of a variable, make sure to obtain "
          "the Tensor backing it by calling `.read_value()` on the variable in "
          f"question: {value}")
    if value is None and self._is_grad_func:
      # Gradient functions may legitimately return a list that contains
      # both Tensors and Python Nones. Unfortunately this breaks the
      # OpKernel, so for now we replace None objects with zeros, which is
      # mathematically correct but will prevent short-circuiting gradient
      # computations.
      #
      # TODO(akshayka): Make it possible to return a list of both Tensors and
      # Nones from an EagerPyFunc.
      return constant_op.constant(0.0, dtype=dtype)
    return ops.convert_to_tensor(value, dtype=dtype)

  def __call__(self, device, token, args):
    """Calls `self._func` in eager mode, recording the tape if needed."""
    use_tape_cache = (
        self._support_graph_mode_gradient or tape_lib.could_possibly_record())

    if use_tape_cache:
      with backprop.GradientTape() as tape:
        for tensor in args:
          for t in nest.flatten(tensor):
            if backprop_util.IsTrainable(t):
              tape.watch(t)
        outputs = self._call(device, args)
      tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
    else:
      outputs = self._call(device, args)

    return outputs

  def _call(self, device, args):
    """Passes `args` to `self._func`, which is executed eagerly."""
    with context.eager_mode():
      ret = self._func(*args)
      # copy the returned tensors to the PyFunc op's device if necessary.
      device_name = device
      if device_name is None:
        # "None" here means "CPU", from the nullptr convention with C++ device
        # pointers.
        device_name = "/job:localhost/replica:0/task:0/device:CPU:0"
      with ops.device(device):
        if isinstance(ret, (tuple, list)):
          outputs = [
              _maybe_copy_to_context_device(self._convert(x, dtype=dtype),
                                            device_name)
              for (x, dtype) in zip(ret, self._out_dtypes)
          ]
        elif ret is None:
          outputs = None
        else:
          outputs = _maybe_copy_to_context_device(
              self._convert(ret, dtype=self._out_dtypes[0]), device_name)
    return outputs


class FuncRegistry:
  """A helper class to keep track of registered py functions.

  FuncRegistry keeps a map from unique tokens (strings) to Python
  functions, which take numpy arrays and output numpy arrays.
  """

  def __init__(self):
    self._lock = threading.Lock()
    self._unique_id = 0  # GUARDED_BY(self._lock)
    # Only store weakrefs to the functions. The strong reference is stored in
    # the graph.
    self._funcs = weakref.WeakValueDictionary()

  @property
  def _ctx(self):
    # N.B. This is needed to support calling py_func with GPU tensors,
    # which must be transferred to CPU if used in any of the NumPy APIs.
    context.ensure_initialized()
    return context.context()._handle  # pylint: disable=protected-access

  def insert(self, func):
    """Registers `func` and returns a unique token for this entry."""
    token = self._next_unique_token()
    # Store a weakref to the function
    self._funcs[token] = func
    return token

  def remove(self, token):
    """Removes the registered function corresponding to `token`."""
    self._funcs.pop(token, None)

  def get(self, token, default=None):
    """Gets the registered function corresponding to `token`."""
    return self._funcs.get(token, default)

  @staticmethod
  def _convert(value, dtype=None):
    """Converts an arg to numpy, avoiding dangerous string and unicode dtypes.

    Numpy pads with zeros when using string and unicode dtypes if different
    components of a tensor have different lengths.  This is bad: ignoring the
    padding is wrong for text data, and removing the padding is wrong for binary
    data.  To avoid this bug, we redo the conversion using an object dtype.
    Additionally, we convert unicode strings to (byte-)strings for
    compatibility.

    Args:
      value: Value to convert to a numpy array.
      dtype: (Optional.) Desired NumPy type for the returned value.

    Returns:
      A numpy array.
    """
    result = np.asarray(value, dtype=dtype, order="C")
    if result.dtype.char == "S" and result is not value:
      return np.asarray(value, order="C", dtype=object)
    elif result.dtype.char == "U" and result is not value:
      value = np.vectorize(lambda x: x.encode("utf8"))(value)
      return np.asarray(value, order="C", dtype=object)
    elif result.dtype.char == "U":
      return result.astype(np.bytes_)
    else:
      return result
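
  # Illustrative sketch (not executed): this is why _convert falls back to an
  # object dtype for strings. With unequal-length byte strings, NumPy picks a
  # fixed-width dtype and zero-pads the shorter elements:
  #
  #   np.asarray([b"a", b"abc"]).dtype                 # dtype('S3')
  #   np.asarray([b"a", b"abc"], dtype=object).dtype   # dtype('O'), no padding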

  def __call__(self, token, device, args):
    """Calls the registered function for `token` with args.

    Args:
      token: A key into this `FuncRegistry` identifying which function to call.
      device: Name of the device on which outputs of `token`'s corresponding
        operation should be placed. Used iff the function registered for `token`
        is an EagerPyFunc.
      args: The arguments to pass to the function registered for `token`.

    Returns:
      The output of the function registered for `token`.

    Raises:
      ValueError: if no function is registered for `token`.
    """
    func = self.get(token, None)
    if func is None:
      raise ValueError(f"Could not find callback with key={token} in the "
                       "registry.")
    if isinstance(func, EagerFunc):
      # NB: Different invocations of the same py_func will share the same
      # token, and the entries they stash in the tape_cache will collide.
      # In practice, when executing a graph, this should only happen if
      # the py_func is in a while_loop whose iterations are run in parallel
      # or if the graph is being driven by concurrent session.run() calls.
      #
      # TODO(akshayka): Key the tape cache in a thread-safe way.
      return func(device, token, args)
    else:
      ret = func(*args)
      # Strings seem to lead to a memory leak here if they're not wrapped in a
      # list.
      if isinstance(ret, bytes):
        ret = [ret]
      # Ensures that we return either a single numpy array or a list of numpy
      # arrays.
      if isinstance(ret, (tuple, list)):
        return [self._convert(x) for x in ret]
      else:
        return self._convert(ret)

  def size(self):
    """Returns how many functions are currently registered."""
    return len(self._funcs)

  def _next_unique_token(self):
    """Returns a unique token."""
    with self._lock:
      uid = self._unique_id
      self._unique_id += 1
    return "pyfunc_%d" % uid


# Global registry for py functions.
_py_funcs = FuncRegistry()

_pywrap_py_func.initialize_py_trampoline(_py_funcs)
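
# Illustrative sketch (not executed): FuncRegistry stores only weak references
# to the registered callables, so callers must keep a strong reference alive
# themselves (graphs do this via `_py_funcs_used_in_graph` below). Roughly,
# with a hypothetical function `fn`:
#
#   def fn(x):
#     return x + 1
#   token = _py_funcs.insert(fn)      # e.g. "pyfunc_0"
#   assert _py_funcs.get(token) is fn
#   _py_funcs.remove(token)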


def _internal_py_func(func,
                      inp,
                      Tout,
                      stateful=None,
                      use_eager_py_func=False,
                      is_grad_func=False,
                      name=None):
  """See documentation for py_func and eager_py_func."""
  if not callable(func):
    raise ValueError(
        f"Expected func to be callable. Received func={func} of type "
        f"{type(func)}.")

  original_func = func
  func = autograph.do_not_convert(func)
  inp = variable_utils.convert_variables_to_tensors(list(inp))

  # Normalize Tout.
  is_list_or_tuple = isinstance(Tout, (list, tuple))
  Tout = Tout if is_list_or_tuple else [Tout]
  Tout = [_as_dtype_or_type_spec(t) for t in Tout]

  # Check if we need to handle CompositeTensor inputs or outputs.
  handle_composite_tensors = (
      use_eager_py_func and
      (any(isinstance(v, composite_tensor.CompositeTensor) for v in inp) or
       any(isinstance(t, type_spec.TypeSpec) for t in Tout)))
  if handle_composite_tensors:
    func, inp, Tout, out_structure = _wrap_for_composites(func, inp, Tout)

  if use_eager_py_func:
    func = EagerFunc(func, Tout, is_grad_func)

  # Tying the registered function's lifetime with the current default graph is
  # not reliable. For example, Estimator-based binaries may switch graphs in
  # between model training and evaluation, via saved_model. Those binaries work
  # because the original function is global, and break once the registered
  # function is an anonymous lambda, like the one produced by do_not_convert.
  # To avoid breaking those cases, we attach the wrapper to the original
  # function so that their lifetimes are connected.
  # TODO(b/144286616): Remove this.
  if tf_inspect.isfunction(original_func):
    # Note: this check is needed because original_func may be a descriptor
    # (https://docs.python.org/3/howto/descriptor.html)
    # and we can't attach attributes to those.
    original_func.ag_dnc_wrapper__ = func

  token = _py_funcs.insert(func)
  # We tie the registered function's lifetime with the current default graph,
  # i.e., when the current graph is destroyed, we remove its py funcs.
  graph = ops.get_default_graph()

  while True:
    current_graph = graph
    if isinstance(graph, function._FuncGraph):  # pylint: disable=protected-access
      graph = graph._outer_graph  # pylint: disable=protected-access
    elif isinstance(graph, func_graph.FuncGraph):
      graph = graph.outer_graph
    if graph is current_graph:
      break

  # TODO(zhifengc): Consider adding a Graph method to collect
  # `cleanup` objects in one of its members.
  if not hasattr(graph, "_py_funcs_used_in_graph"):
    graph._py_funcs_used_in_graph = []  # pylint: disable=protected-access

  # Store a reference to the function in the graph to ensure it stays alive
  # as long as the graph lives. When the graph is destroyed, the function
  # is left to the garbage collector for destruction as well.
  graph._py_funcs_used_in_graph.append(func)  # pylint: disable=protected-access

  if use_eager_py_func:
    result = gen_script_ops.eager_py_func(
        input=inp,
        token=token,
        is_async=context.is_async(),
        Tout=Tout,
        name=name)
  else:
    if stateful:
      result = gen_script_ops.py_func(
          input=inp, token=token, Tout=Tout, name=name)
    else:
      result = gen_script_ops.py_func_stateless(
          input=inp, token=token, Tout=Tout, name=name)

  if handle_composite_tensors and Tout:
    result = nest.pack_sequence_as(
        out_structure, result, expand_composites=True)

  return result if is_list_or_tuple else result[0]


# TODO(akshayka): Implement higher-order derivatives.
@ops.RegisterGradient("EagerPyFunc")
def _EagerPyFuncGrad(op, *dy):
  """Computes the gradient of an EagerPyFunc."""

  token = op.get_attr("token")

  def eagerly_executed_grad(*dy):
    tape, eager_inputs, eager_outputs = tape_cache.pop(compat.as_bytes(token))
    return tape.gradient(eager_outputs, eager_inputs, output_gradients=dy)

  with ops.control_dependencies(op.outputs):
    gradient_op = _internal_py_func(
        func=eagerly_executed_grad,
        inp=dy,
        Tout=[tensor.dtype for tensor in op.inputs],
        use_eager_py_func=True,
        is_grad_func=True)

  if not context.executing_eagerly():
    # In graph mode, we look up the EagerFunc object by its token and
    # notify it that it needs to support gradients.
    func = _py_funcs.get(token.decode())
    assert isinstance(func, EagerFunc), (
        f"EagerPyFuncGrad called on a non-EagerFunc object: {func}.")
    func.set_support_graph_mode_gradient()
  return gradient_op
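
# Illustrative sketch (not executed): with the gradient registered above,
# gradients flow through `tf.py_function` both eagerly and inside a
# `tf.function` (graph mode). Assuming TF 2.x; the function name below is
# purely illustrative:
#
#   @tf.function
#   def grad_of_py_function(x):
#     with tf.GradientTape() as tape:
#       tape.watch(x)
#       y = tf.py_function(lambda t: t * t, inp=[x], Tout=tf.float32)
#     return tape.gradient(y, x)
#
#   grad_of_py_function(tf.constant(3.0))  # expected to be close to 6.0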


@tf_export("py_function")
@dispatch.add_dispatch_support
def eager_py_func(func, inp, Tout, name=None):
  """Wraps a python function into a TensorFlow op that executes it eagerly.

  This function allows expressing computations in a TensorFlow graph as
  Python functions. In particular, it wraps a Python function `func`
  in a once-differentiable TensorFlow operation that executes it with eager
  execution enabled. As a consequence, `tf.py_function` makes it
  possible to express control flow using Python constructs (`if`, `while`,
  `for`, etc.), instead of TensorFlow control flow constructs (`tf.cond`,
  `tf.while_loop`). For example, you might use `tf.py_function` to
  implement the log huber function:

  ```python
  def log_huber(x, m):
    if tf.abs(x) <= m:
      return x**2
    else:
      return m**2 * (1 - 2 * tf.math.log(m) + tf.math.log(x**2))

  x = tf.constant(1.0)
  m = tf.constant(2.0)

  with tf.GradientTape() as t:
    t.watch([x, m])
    y = tf.py_function(func=log_huber, inp=[x, m], Tout=tf.float32)

  dy_dx = t.gradient(y, x)
  assert dy_dx.numpy() == 2.0
  ```

  You can also use `tf.py_function` to debug your models at runtime
  using Python tools, i.e., you can isolate portions of your code that
  you want to debug, wrap them in Python functions and insert `pdb` tracepoints
  or print statements as desired, and wrap those functions in
  `tf.py_function`.

  For more information on eager execution, see the
  [Eager guide](https://tensorflow.org/guide/eager).

  `tf.py_function` is similar in spirit to `tf.compat.v1.py_func`, but unlike
  the latter, the former lets you use TensorFlow operations in the wrapped
  Python function. In particular, while `tf.compat.v1.py_func` only runs on CPUs
  and wraps functions that take NumPy arrays as inputs and return NumPy arrays
  as outputs, `tf.py_function` can be placed on GPUs and wraps functions
  that take Tensors as inputs, execute TensorFlow operations in their bodies,
  and return Tensors as outputs.

  Note: We recommend avoiding `tf.py_function` outside of prototyping
  and experimentation due to the following known limitations:

  * Calling `tf.py_function` will acquire the Python Global Interpreter Lock
    (GIL) that allows only one thread to run at any point in time. This will
    preclude efficient parallelization and distribution of the execution of the
    program.

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.py_function()`. If you are using distributed
    TensorFlow, you must run a `tf.distribute.Server` in the same process as the
    program that calls `tf.py_function()` and you must pin the created
    operation to a device in that server (e.g. using `with tf.device():`).

  * Currently `tf.py_function` is not compatible with XLA. Calling
    `tf.py_function` inside `tf.function(jit_compile=True)` will raise an
    error.

  Args:
    func: A Python function that accepts `inp` as arguments, and returns a
      value (or list of values) whose type is described by `Tout`.

    inp: Input arguments for `func`.  A list whose elements are `Tensor`s or
      `CompositeTensors` (such as `tf.RaggedTensor`); or a single `Tensor` or
      `CompositeTensor`.

    Tout: The type(s) of the value(s) returned by `func`.  One of the
      following.

      * If `func` returns a `Tensor` (or a value that can be converted to a
        Tensor): the `tf.DType` for that value.
      * If `func` returns a `CompositeTensor`: The `tf.TypeSpec` for that value.
      * If `func` returns `None`: the empty list (`[]`).
      * If `func` returns a list of `Tensor` and `CompositeTensor` values:
        a corresponding list of `tf.DType`s and `tf.TypeSpec`s for each value.

    name: A name for the operation (optional).

  Returns:
    The value(s) computed by `func`: a `Tensor`, `CompositeTensor`, or list of
    `Tensor` and `CompositeTensor`; or an empty list if `func` returns `None`.
  """
  if ops.executing_eagerly_outside_functions():
    with ops.device(context.context().host_address_space()):
      return _internal_py_func(
          func=func, inp=inp, Tout=Tout, use_eager_py_func=True, name=name)

  return _internal_py_func(
      func=func, inp=inp, Tout=Tout, use_eager_py_func=True, name=name)
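
# Illustrative sketch (not executed): `tf.py_function` also accepts composite
# tensors such as `tf.RaggedTensor`, provided `Tout` is given as a
# `tf.TypeSpec`. Assuming TF 2.x eager execution; the values below are purely
# illustrative:
#
#   rt = tf.ragged.constant([[1.0, 2.0], [3.0]])
#   doubled = tf.py_function(
#       func=lambda x: x * 2,
#       inp=[rt],
#       Tout=tf.RaggedTensorSpec(shape=[None, None], dtype=tf.float32))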


def py_func_common(func, inp, Tout, stateful=True, name=None):
  """Wraps a python function and uses it as a TensorFlow op.

  Given a python function `func`, which takes numpy arrays as its
  arguments and returns numpy arrays as its outputs, wrap this function as an
  operation in a TensorFlow graph. The following snippet constructs a simple
  TensorFlow graph that invokes the `np.sinh()` NumPy function as an operation
  in the graph:

  ```python
  def my_func(x):
    # x will be a numpy array with the contents of the placeholder below
    return np.sinh(x)
  input = tf.compat.v1.placeholder(tf.float32)
  y = tf.compat.v1.py_func(my_func, [input], tf.float32)
  ```

  **N.B.** The `tf.compat.v1.py_func()` operation has the following known
  limitations:

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.compat.v1.py_func()`. If you are using distributed
    TensorFlow, you must run a `tf.distribute.Server` in the same process as
    the program that calls `tf.compat.v1.py_func()` and you must pin the
    created operation to a device in that server (e.g. using
    `with tf.device():`).

  Note: It produces tensors of unknown shape and rank, as shape inference
    does not work on arbitrary Python code. If you need the shape, you need
    to set it based on statically available information.

    E.g.
    ```python
    import tensorflow as tf
    import numpy as np

    def make_synthetic_data(i):
        return np.cast[np.uint8](i) * np.ones([20, 256, 256, 3],
                dtype=np.float32) / 10.

    def preprocess_fn(i):
        ones = tf.py_function(make_synthetic_data, [i], tf.float32)
        ones.set_shape(tf.TensorShape([None, None, None, None]))
        ones = tf.image.resize(ones, [224, 224])
        return ones

    ds = tf.data.Dataset.range(10)
    ds = ds.map(preprocess_fn)
    ```

  Args:
    func: A Python function, which accepts `ndarray` objects as arguments and
      returns a list of `ndarray` objects (or a single `ndarray`). This function
      must accept as many arguments as there are tensors in `inp`, and these
      argument types will match the corresponding `tf.Tensor` objects in `inp`.
      The returned `ndarray`s must match the number and types defined in `Tout`.
      Important Note: Input and output numpy `ndarray`s of `func` are not
        guaranteed to be copies. In some cases their underlying memory will be
        shared with the corresponding TensorFlow tensors. In-place modification
        or storing `func` input or return values in python datastructures
        without explicit (np.)copy can have non-deterministic consequences.
    inp: A list of `Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns.
    stateful: (Boolean.) If True, the function should be considered stateful. If
      a function is stateless, when given the same input it will return the same
      output and have no observable side effects. Optimizations such as common
      subexpression elimination are only performed on stateless operations.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` or a single `Tensor` which `func` computes.

  @compatibility(TF2)

  This name was deprecated and removed in TF2, but `tf.numpy_function` is a
  near-exact replacement, just drop the `stateful` argument (all
  `tf.numpy_function` calls are considered stateful). It is compatible with
  eager execution and `tf.function`.

  `tf.py_function` is a close but not an exact replacement, passing TensorFlow
  tensors to the wrapped function instead of NumPy arrays, which provides
  gradients and can take advantage of accelerators.

  Before:

  >>> def fn_using_numpy(x):
  ...   x[0] = 0.
  ...   return x
  >>> tf.compat.v1.py_func(fn_using_numpy, inp=[tf.constant([1., 2.])],
  ...     Tout=tf.float32, stateful=False)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 2.], dtype=float32)>

  After:

  >>> tf.numpy_function(fn_using_numpy, inp=[tf.constant([1., 2.])],
  ...     Tout=tf.float32)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 2.], dtype=float32)>

  @end_compatibility

  """
  if context.executing_eagerly():
    result = func(*[np.array(x) for x in inp])
    result = nest.flatten(result)

    result = [x if x is None else ops.convert_to_tensor(x) for x in result]
    if len(result) == 1:
      # Mimic the automatic unwrapping in graph-mode py_func
      result, = result
    return result

  if ops.executing_eagerly_outside_functions():
    with ops.device(context.context().host_address_space()):
      return _internal_py_func(
          func=func,
          inp=inp,
          Tout=Tout,
          stateful=stateful,
          use_eager_py_func=False,
          name=name)

  return _internal_py_func(
      func=func,
      inp=inp,
      Tout=Tout,
      stateful=stateful,
      use_eager_py_func=False,
      name=name)


@deprecation.deprecated(
    date=None,
    instructions="""tf.py_func is deprecated in TF V2. Instead, there are two
    options available in V2.
    - tf.py_function takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.
    - tf.numpy_function maintains the semantics of the deprecated tf.py_func
    (it is not differentiable, and manipulates numpy arrays). It drops the
    stateful argument, making all functions stateful.
    """)
@tf_export(v1=["py_func"])
@dispatch.add_dispatch_support
def py_func(func, inp, Tout, stateful=True, name=None):
  return py_func_common(func, inp, Tout, stateful, name=name)


py_func.__doc__ = "%s" % py_func_common.__doc__


@tf_export("numpy_function")
@dispatch.add_dispatch_support
def numpy_function(func, inp, Tout, stateful=True, name=None):
  """Wraps a python function and uses it as a TensorFlow op.

  Given a python function `func`, wrap this function as an operation in a
  TensorFlow function. `func` must take numpy arrays as its arguments and
  return numpy arrays as its outputs.

  The following example creates a TensorFlow graph with `np.sinh()` as an
  operation in the graph:

  >>> def my_numpy_func(x):
  ...   # x will be a numpy array with the contents of the input to the
  ...   # tf.function
  ...   return np.sinh(x)
  >>> @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
  ... def tf_function(input):
  ...   y = tf.numpy_function(my_numpy_func, [input], tf.float32)
  ...   return y * y
  >>> tf_function(tf.constant(1.))
  <tf.Tensor: shape=(), dtype=float32, numpy=1.3810978>

  Comparison to `tf.py_function`:
  `tf.py_function` and `tf.numpy_function` are very similar, except that
  `tf.numpy_function` takes numpy arrays, and not `tf.Tensor`s. If you want the
  function to contain `tf.Tensor`s, and have any TensorFlow operations executed
  in the function be differentiable, please use `tf.py_function`.

  Note: We recommend avoiding `tf.numpy_function` outside of
  prototyping and experimentation due to the following known limitations:

  * Calling `tf.numpy_function` will acquire the Python Global Interpreter Lock
    (GIL) that allows only one thread to run at any point in time. This will
    preclude efficient parallelization and distribution of the execution of the
    program. Therefore, you are discouraged from using `tf.numpy_function`
    outside of prototyping and experimentation.

  * The body of the function (i.e. `func`) will not be serialized in a
    `tf.SavedModel`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.numpy_function()`. If you are using distributed
    TensorFlow, you must run a `tf.distribute.Server` in the same process as the
    program that calls `tf.numpy_function()` and you must pin the created
    operation to a device in that server (e.g. using `with tf.device():`).

  * Currently `tf.numpy_function` is not compatible with XLA. Calling
    `tf.numpy_function` inside `tf.function(jit_compile=True)` will raise an
    error.

  * Since the function takes numpy arrays, you cannot take gradients
    through a numpy_function. If you require something that is differentiable,
    please consider using `tf.py_function`.

  Args:
    func: A Python function, which accepts `numpy.ndarray` objects as arguments
      and returns a list of `numpy.ndarray` objects (or a single
      `numpy.ndarray`). This function must accept as many arguments as there are
      tensors in `inp`, and these argument types will match the corresponding
      `tf.Tensor` objects in `inp`. The returned `numpy.ndarray`s must match the
      number and types defined in `Tout`.
      Important Note: Input and output `numpy.ndarray`s of `func` are not
        guaranteed to be copies. In some cases their underlying memory will be
        shared with the corresponding TensorFlow tensors. In-place modification
        or storing `func` input or return values in python datastructures
        without explicit (np.)copy can have non-deterministic consequences.
    inp: A list of `tf.Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns.
    stateful: (Boolean.) Setting this argument to False tells the runtime to
      treat the function as stateless, which enables certain optimizations.
      A function is stateless when, given the same input, it will return the
      same output and have no side effects; its only purpose is to have a
      return value.
      The behavior for a stateful function with the `stateful` argument False
      is undefined. In particular, caution should be taken when
      mutating the input arguments as this is a stateful operation.
    name: (Optional) A name for the operation.

  Returns:
    Single or list of `tf.Tensor` which `func` computes.
  """
  return py_func_common(func, inp, Tout, stateful=stateful, name=name)


def _as_dtype_or_type_spec(t):
  return t if isinstance(t, type_spec.TypeSpec) else dtypes.as_dtype(t)


def _wrap_for_composites(func, inp, Tout):
  """Wraps user inputs to support composite tensors for `py_function`.

  1. Flattens `inp` to a list of Tensors (by flattening any composite tensors).
  2. Creates a wrapper function for `func` that expects flat inputs and:
     - Packs the inputs into the input structure expected by `func`.
     - Calls `func` with the packed inputs.
     - Checks that `func`'s output matches `Tout`.
     - Flattens `func`'s output to a list of Tensors (flattening any composite
       tensors).

  Args:
    func: The function to wrap (`func` argument to `py_function`).
    inp: The input arguments for `func` (`inp` argument to `py_function`).
    Tout: The expected output types for `func` (`Tout` argument to
      `py_function`).

  Returns:
    A tuple `(func, inp, Tout, out_structure)`, where `func` is the wrapped
    function, `inp` is the flattened inputs, `Tout` is the list of expected
    dtypes for the flattened outputs, and `out_structure` is the expected
    output structure (which can be used to pack the output tensors).
  """
  in_structure = [
      v if isinstance(v, composite_tensor.CompositeTensor) else 1 for v in inp
  ]
  inp = nest.flatten_up_to(in_structure, inp, expand_composites=True)
  out_structure = Tout
  Tout = [
      v.dtype if isinstance(v, tensor_spec.TensorSpec) else v
      for v in nest.flatten(Tout, expand_composites=True)
  ]

  def wrapped_func(*flat_inp):
    structured_inp = nest.pack_sequence_as(
        in_structure, flat_inp, expand_composites=True)
    out = func(*structured_inp)
    if not out_structure:
      return []  # Ignore return value if none is requested/expected.
    if not isinstance(out, (list, tuple)):
      out = [out]  # func may return a single value instead of a list.
    flat_out = []
    for elt, expected_type in zip(out, out_structure):
      if (isinstance(expected_type, type_spec.TypeSpec) and
          not isinstance(expected_type, tensor_spec.TensorSpec)):
        if not expected_type.is_compatible_with(elt):
          # pylint: disable=protected-access
          raise ValueError(
              f"py_function: func={func} returned {out!r}, "
              f"which did not match Tout={out_structure!r}.\nIn particular, "
              f"{elt!r} is not compatible with {expected_type!r}.")
        flat_out.extend(nest.flatten(elt, expand_composites=True))
      else:
        # Pro-actively check if the return value is a composite tensor when
        # we expect a Tensor.  We would catch this later (when we call
        # convert_to_tensor), but checking it here lets us give a better
        # error message.
        if isinstance(elt, composite_tensor.CompositeTensor):
          raise ValueError(
              f"py_function: func={func} returned {out!r}, "
              f"which did not match Tout={out_structure!r}.\nIn particular, "
              f"{elt!r} is not a Tensor.")
        flat_out.append(elt)
    return flat_out

  return wrapped_func, inp, Tout, out_structure
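
# Illustrative sketch (not executed): roughly how _wrap_for_composites treats a
# ragged input when `Tout` contains a `tf.RaggedTensorSpec`; the values below
# are purely illustrative:
#
#   rt = tf.ragged.constant([[1, 2], [3]])
#   spec = tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32)
#   wrapped, flat_inp, flat_tout, structure = _wrap_for_composites(
#       lambda x: x, [rt], [spec])
#   # flat_inp holds rt's component tensors, flat_tout holds their dtypes, and
#   # `structure` is later used to repack the flat outputs into a RaggedTensor.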


ops.NotDifferentiable("PyFunc")
ops.NotDifferentiable("PyFuncStateless")