1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Script Language Operators.""" 16 17# pylint: disable=g-bad-name 18import threading 19 20# Used by py_util.cc to get tracebacks. 21import traceback # pylint: disable=unused-import 22import weakref 23 24import numpy as np 25 26from tensorflow.python.eager import backprop 27from tensorflow.python.eager import backprop_util 28from tensorflow.python.eager import context 29from tensorflow.python.eager import tape as tape_lib 30from tensorflow.python.framework import composite_tensor 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import func_graph 34from tensorflow.python.framework import function 35from tensorflow.python.framework import ops 36from tensorflow.python.framework import tensor_spec 37from tensorflow.python.framework import type_spec 38from tensorflow.python.lib.core import _pywrap_py_func 39from tensorflow.python.ops import gen_script_ops 40from tensorflow.python.ops import resource_variable_ops 41from tensorflow.python.util import compat 42from tensorflow.python.util import deprecation 43from tensorflow.python.util import dispatch 44from tensorflow.python.util import lazy_loader 45from tensorflow.python.util import nest 46from tensorflow.python.util import tf_inspect 
from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export

autograph = lazy_loader.LazyLoader(
    "autograph", globals(),
    "tensorflow.python.autograph.impl.api")


# Map from EagerPyFunc token (as bytes) to tuple
# (tape, eager args, eager outputs); used for differentiation.
tape_cache = {}


def _maybe_copy_to_context_device(tensor, device_name):
  """Copy an EagerTensor to the current device if it's not on `device_name`.

  Args:
    tensor: The EagerTensor to (possibly) copy.
    device_name: Name of the device the tensor is expected to live on.

  Returns:
    `tensor` itself if its `backing_device` equals `device_name`, otherwise a
    copy produced by `EagerTensor._copy()`.
  """
  in_device = tensor.backing_device
  if device_name == in_device:
    return tensor
  else:
    # Note that EagerTensor._copy bypasses the placer and copies to the context
    # device, which means e.g. int32 Tensors which would normally be forced onto
    # the CPU can instead be placed on the GPU. This is necessary so that the
    # PyFunc kernel always returns Tensors on the device it's executing on.
    return tensor._copy()  # pylint: disable=protected-access


class EagerFunc:
  """A wrapper for a function owned by an EagerPyFunc."""

  def __init__(self, func, Tout, is_grad_func):
    """Constructs an EagerFunc.

    Args:
      func: The function to wrap.
      Tout: A list of datatypes for the output; an empty list if the output is
        None.
      is_grad_func: Whether this EagerFunc is the gradient of another
        EagerPyFunc.
    """
    self._func = func
    self._out_dtypes = Tout
    self._is_grad_func = is_grad_func
    # Set by `set_support_graph_mode_gradient`; forces tape recording in
    # `__call__` even when no eager tape could be recording.
    self._support_graph_mode_gradient = False

  def set_support_graph_mode_gradient(self):
    """Indicates the object shall support gradient ops.

    This function is internally used by _EagerPyFuncGrad to support
    graph mode gradient of EagerFunc via tf.gradient().
    """
    self._support_graph_mode_gradient = True

  def _convert(self, value, dtype):
    """Converts `value` to a tensor of type `dtype`, with error checking.

    Args:
      value: The value returned by the wrapped function.
      dtype: The desired dtype.

    Returns:
      A tensor of type `dtype`, or a zeros tensor if value is None and
      this function is in fact a gradient function.

    Raises:
      RuntimeError: if `value` is a variable.
    """

    if isinstance(value, resource_variable_ops.ResourceVariable):
      raise RuntimeError(
          "Attempting to return a variable from an eagerly executed py_func. "
          "Only numeric data structures like Tensors or NumPy arrays should "
          "be returned; to return the value of a variable, make sure to obtain "
          "the Tensor backing it by calling `.read_value()` on the variable in "
          f"question: {value}")
    if value is None and self._is_grad_func:
      # Gradient functions may legitimately return a list that contains
      # both Tensors and Python Nones. Unfortunately this breaks the
      # OpKernel, so for now we replace None objects with zeros, which is
      # mathematically correct but will prevent short-circuiting gradient
      # computations.
      #
      # TODO(akshayka): Make it possible to return a list of both Tensors and
      # Nones from an EagerPyFunc.
      return constant_op.constant(0.0, dtype=dtype)
    return ops.convert_to_tensor(value, dtype=dtype)

  def __call__(self, device, token, args):
    """Calls `self._func` in eager mode, recording the tape if needed.

    When graph-mode gradients were requested or a tape could possibly be
    recording, the call runs under a `GradientTape` that watches every
    trainable (flattened) input, and `(tape, args, outputs)` is stored in
    `tape_cache` under the bytes form of `token` for `_EagerPyFuncGrad`.

    Args:
      device: Device the outputs should be placed on (None means CPU).
      token: Registry token of this function.
      args: Positional arguments for the wrapped function.

    Returns:
      The converted outputs of the wrapped function.
    """
    use_tape_cache = (
        self._support_graph_mode_gradient or tape_lib.could_possibly_record())

    if use_tape_cache:
      with backprop.GradientTape() as tape:
        for tensor in args:
          for t in nest.flatten(tensor):
            if backprop_util.IsTrainable(t):
              tape.watch(t)
        outputs = self._call(device, args)
      tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
    else:
      outputs = self._call(device, args)

    return outputs

  def _call(self, device, args):
    """Passes `args` to `self._func`, which is executed eagerly.

    Returns:
      None if the function returned None; a list of tensors if it returned a
      list/tuple (paired with `Tout` via `zip`); otherwise a single tensor
      converted to `Tout[0]`.
    """
    with context.eager_mode():
      ret = self._func(*args)
      # copy the returned tensors to the PyFunc op's device if necessary.
      device_name = device
      if device_name is None:
        # "None" here means "CPU", from the nullptr convention with C++ device
        # pointers.
        device_name = "/job:localhost/replica:0/task:0/device:CPU:0"
      with ops.device(device):
        if isinstance(ret, (tuple, list)):
          outputs = [
              _maybe_copy_to_context_device(self._convert(x, dtype=dtype),
                                            device_name)
              for (x, dtype) in zip(ret, self._out_dtypes)
          ]
        elif ret is None:
          outputs = None
        else:
          outputs = _maybe_copy_to_context_device(
              self._convert(ret, dtype=self._out_dtypes[0]), device_name)
    return outputs


class FuncRegistry:
  """A helper class to keep track of registered py functions.

  FuncRegistry keeps a map from unique tokens (string) to python
  functions, which takes numpy arrays and outputs numpy arrays.
  """

  def __init__(self):
    self._lock = threading.Lock()
    self._unique_id = 0  # GUARDED_BY(self._lock)
    # Only store weakrefs to the functions. The strong reference is stored in
    # the graph.
    self._funcs = weakref.WeakValueDictionary()

  @property
  def _ctx(self):
    # N.B. This is needed to support calling py_func with GPU tensors,
    # which must be transferred to CPU if used in any of the NumPy APIs.
    context.ensure_initialized()
    return context.context()._handle  # pylint: disable=protected-access

  def insert(self, func):
    """Registers `func` and returns a unique token for this entry."""
    token = self._next_unique_token()
    # Store a weakref to the function
    self._funcs[token] = func
    return token

  def remove(self, token):
    """Removes the registered function corresponding to `token`."""
    self._funcs.pop(token, None)

  def get(self, token, default=None):
    """Gets the registered function corresponding to `token`."""
    return self._funcs.get(token, default)

  @staticmethod
  def _convert(value, dtype=None):
    """Converts an arg to numpy, avoiding dangerous string and unicode dtypes.

    Numpy pads with zeros when using string and unicode dtypes if different
    components of a tensor have different lengths. This is bad: ignoring the
    padding is wrong for text data, and removing the padding is wrong for binary
    data. To avoid this bug, we redo the conversion using an object dtype.
    Additionally, we convert unicode strings to UTF-8 encoded (byte-)strings for
    compatibility, whether they arrive as Python `str` objects or as a NumPy
    unicode array.

    Args:
      value: Value to convert to a numpy array.
      dtype: (Optional.) Desired NumPy type for the returned value.

    Returns:
      A numpy array. String and unicode data is returned as a C-ordered object
      array of `bytes`; an existing bytes (`S`) ndarray is returned unchanged.
    """
    result = np.asarray(value, dtype=dtype, order="C")
    if result.dtype.char == "S" and result is not value:
      return np.asarray(value, order="C", dtype=object)
    elif result.dtype.char == "U":
      # Encode each element of the *original* value as UTF-8. `otypes` is pinned
      # to `object` so that (a) the encoded bytes are never materialized in a
      # fixed-width "S" array (which would strip trailing NUL bytes), and
      # (b) size-0 inputs work (`np.vectorize` cannot infer an output type
      # from zero elements). Encoding explicitly, instead of
      # `astype(np.bytes_)`, also handles non-ASCII text, which NumPy's
      # unicode-to-bytes cast rejects.
      encode = np.vectorize(lambda x: x.encode("utf8"), otypes=[object])
      return np.asarray(encode(value), order="C", dtype=object)
    else:
      return result

  def __call__(self, token, device, args):
    """Calls the registered function for `token` with args.

    Args:
      token: A key into this `FuncRegistry` identifying which function to call.
      device: Name of the device on which outputs of `token`'s corresponding
        operation should be placed. Used iff the function registered for `token`
        is an EagerPyFunc.
      args: The arguments to pass to the function registered for `token`.

    Returns:
      The output of the function registered for `token`.

    Raises:
      ValueError: if no function is registered for `token`.
    """
    func = self.get(token, None)
    if func is None:
      raise ValueError(f"Could not find callback with key={token} in the "
                       "registry.")
    if isinstance(func, EagerFunc):
      # NB: Different invocations of the same py_func will share the same
      # token, and the entries they stash in the tape_cache will collide.
      # In practice, when executing a graph, this should only happen if
      # the py_func is in a while_loop whose iterations are run in parallel
      # or if the graph is being driven by concurrent session.run() calls.
      #
      # TODO(akshayka): Key the tape cache in a thread-safe way.
      return func(device, token, args)
    else:
      ret = func(*args)
      # Strings seem to lead to a memory leak here if they're not wrapped in a
      # list.
      if isinstance(ret, bytes):
        ret = [ret]
      # Ensures that we return either a single numpy array or a list of numpy
      # arrays.
      if isinstance(ret, (tuple, list)):
        return [self._convert(x) for x in ret]
      else:
        return self._convert(ret)

  def size(self):
    """Returns how many functions are currently registered."""
    return len(self._funcs)

  def _next_unique_token(self):
    """Returns a unique token."""
    with self._lock:
      uid = self._unique_id
      self._unique_id += 1
    return "pyfunc_%d" % uid


# Global registry for py functions.
_py_funcs = FuncRegistry()

# NOTE(review): presumably hands the registry to the C++ runtime so that the
# PyFunc/EagerPyFunc kernels can dispatch calls through `_py_funcs`.
_pywrap_py_func.initialize_py_trampoline(_py_funcs)


def _outermost_graph(graph):
  """Follows function-graph outer links starting at `graph`.

  Steps from a `function._FuncGraph` to its `_outer_graph`, or from a
  `func_graph.FuncGraph` to its `outer_graph`, and stops at the first graph
  that is neither kind, or whose outer link points back at itself.

  Args:
    graph: The graph to start the walk from.

  Returns:
    The graph at which the walk stops.
  """
  while True:
    if isinstance(graph, function._FuncGraph):  # pylint: disable=protected-access
      parent = graph._outer_graph  # pylint: disable=protected-access
    elif isinstance(graph, func_graph.FuncGraph):
      parent = graph.outer_graph
    else:
      return graph
    if parent is graph:
      return graph
    graph = parent


def _internal_py_func(func,
                      inp,
                      Tout,
                      stateful=None,
                      use_eager_py_func=False,
                      is_grad_func=False,
                      name=None):
  """Registers `func` and builds the op that calls it.

  See documentation for py_func and eager_py_func.

  Args:
    func: The Python callable to wrap.
    inp: Iterable of inputs; materialized with `list()` and with variables
      converted to tensors.
    Tout: A single output dtype/`TypeSpec`, or a list/tuple of them.
    stateful: On the non-eager path, a truthy value selects the `PyFunc` op,
      otherwise `PyFuncStateless` is used.
    use_eager_py_func: Whether to wrap `func` in an `EagerFunc` and build an
      `EagerPyFunc` op.
    is_grad_func: Forwarded to `EagerFunc`.
    name: Optional name for the op.

  Returns:
    The op outputs when `Tout` is a list/tuple, otherwise the first output.

  Raises:
    ValueError: if `func` is not callable.
  """
  if not callable(func):
    raise ValueError(
        f"Expected func to be callable. Received func={func} of type "
        f"{type(func)}.")

  user_func = func
  wrapped = autograph.do_not_convert(func)
  inp = variable_utils.convert_variables_to_tensors(list(inp))

  # Internally the output types are always a list; remember what the caller
  # passed so the result can be unwrapped to match at the end.
  wants_sequence = isinstance(Tout, (list, tuple))
  out_types = [
      _as_dtype_or_type_spec(t) for t in (Tout if wants_sequence else [Tout])
  ]

  # Composite inputs / TypeSpec outputs are only handled on the eager path,
  # where they are flattened to plain tensors around the call.
  out_structure = None
  has_composites = use_eager_py_func and (
      any(isinstance(v, composite_tensor.CompositeTensor) for v in inp) or
      any(isinstance(t, type_spec.TypeSpec) for t in out_types))
  if has_composites:
    wrapped, inp, out_types, out_structure = _wrap_for_composites(
        wrapped, inp, out_types)

  if use_eager_py_func:
    wrapped = EagerFunc(wrapped, out_types, is_grad_func)

  # Tying the registered function's lifetime only to the current default graph
  # is not reliable: e.g. Estimator-based binaries may switch graphs between
  # training and evaluation via saved_model. They work because the original
  # function is global, and would break once the registered function is an
  # anonymous lambda like the one made by do_not_convert. Attaching the wrapper
  # to the original function connects their lifetimes.
  # TODO(b/144286616): Remove this.
  if tf_inspect.isfunction(user_func):
    # `user_func` may be a descriptor
    # (https://docs.python.org/3/howto/descriptor.html), which cannot take
    # attributes, hence the isfunction check.
    user_func.ag_dnc_wrapper__ = wrapped

  token = _py_funcs.insert(wrapped)

  # The registry only holds a weak reference, so keep a strong one on the
  # outermost (non-function) graph: the function stays alive as long as that
  # graph does and is left to the garbage collector afterwards.
  # TODO(zhifengc): Consider adding a Graph method to collect
  # `cleanup` objects in one of its member.
  owner_graph = _outermost_graph(ops.get_default_graph())
  if not hasattr(owner_graph, "_py_funcs_used_in_graph"):
    owner_graph._py_funcs_used_in_graph = []  # pylint: disable=protected-access
  owner_graph._py_funcs_used_in_graph.append(wrapped)  # pylint: disable=protected-access

  if use_eager_py_func:
    result = gen_script_ops.eager_py_func(
        input=inp,
        token=token,
        is_async=context.is_async(),
        Tout=out_types,
        name=name)
  elif stateful:
    result = gen_script_ops.py_func(
        input=inp, token=token, Tout=out_types, name=name)
  else:
    result = gen_script_ops.py_func_stateless(
        input=inp, token=token, Tout=out_types, name=name)

  if has_composites and out_types:
    result = nest.pack_sequence_as(
        out_structure, result, expand_composites=True)

  return result if wants_sequence else result[0]


# TODO(akshayka): Implement higher-order derivatives.
@ops.RegisterGradient("EagerPyFunc")
def _EagerPyFuncGrad(op, *dy):
  """Computes the gradient of an EagerPyFunc.

  The gradient is itself an EagerPyFunc. It pops the
  `(tape, eager inputs, eager outputs)` entry that the forward `EagerFunc`
  stored in `tape_cache` under the op's token, and asks that tape for the
  gradient of the outputs with respect to the inputs.

  Args:
    op: The `EagerPyFunc` operation being differentiated.
    *dy: Gradients with respect to each output of `op`.

  Returns:
    The outputs of the gradient EagerPyFunc, with dtypes matching `op.inputs`.
  """

  token = op.get_attr("token")

  def eagerly_executed_grad(*dy):
    tape, eager_inputs, eager_outputs = tape_cache.pop(compat.as_bytes(token))
    return tape.gradient(eager_outputs, eager_inputs, output_gradients=dy)

  # Run after the forward op, whose execution is what populates `tape_cache`.
  with ops.control_dependencies(op.outputs):
    gradient_op = _internal_py_func(
        func=eagerly_executed_grad,
        inp=dy,
        Tout=[tensor.dtype for tensor in op.inputs],
        use_eager_py_func=True,
        is_grad_func=True)

  if not context.executing_eagerly():
    # In graph mode, we find the func object from its token and
    # notify the eager func object it needs to support the gradients.
    func = _py_funcs.get(token.decode())
    assert isinstance(func, EagerFunc), (
        f"EagerPyFuncGrad called on a non-EagerFunc object: {func}.")
    func.set_support_graph_mode_gradient()
  return gradient_op


@tf_export("py_function")
@dispatch.add_dispatch_support
def eager_py_func(func, inp, Tout, name=None):
  """Wraps a python function into a TensorFlow op that executes it eagerly.

  This function allows expressing computations in a TensorFlow graph as
  Python functions. In particular, it wraps a Python function `func`
  in a once-differentiable TensorFlow operation that executes it with eager
  execution enabled. As a consequence, `tf.py_function` makes it
  possible to express control flow using Python constructs (`if`, `while`,
  `for`, etc.), instead of TensorFlow control flow constructs (`tf.cond`,
  `tf.while_loop`). For example, you might use `tf.py_function` to
  implement the log huber function:

  ```python
  def log_huber(x, m):
    if tf.abs(x) <= m:
      return x**2
    else:
      return m**2 * (1 - 2 * tf.math.log(m) + tf.math.log(x**2))

  x = tf.constant(1.0)
  m = tf.constant(2.0)

  with tf.GradientTape() as t:
    t.watch([x, m])
    y = tf.py_function(func=log_huber, inp=[x, m], Tout=tf.float32)

  dy_dx = t.gradient(y, x)
  assert dy_dx.numpy() == 2.0
  ```

  You can also use `tf.py_function` to debug your models at runtime
  using Python tools, i.e., you can isolate portions of your code that
  you want to debug, wrap them in Python functions and insert `pdb` tracepoints
  or print statements as desired, and wrap those functions in
  `tf.py_function`.

  For more information on eager execution, see the
  [Eager guide](https://tensorflow.org/guide/eager).

  `tf.py_function` is similar in spirit to `tf.compat.v1.py_func`, but unlike
  the latter, the former lets you use TensorFlow operations in the wrapped
  Python function. In particular, while `tf.compat.v1.py_func` only runs on CPUs
  and wraps functions that take NumPy arrays as inputs and return NumPy arrays
  as outputs, `tf.py_function` can be placed on GPUs and wraps functions
  that take Tensors as inputs, execute TensorFlow operations in their bodies,
  and return Tensors as outputs.

  Note: We recommend avoiding `tf.py_function` outside of prototyping
  and experimentation due to the following known limitations:

  * Calling `tf.py_function` will acquire the Python Global Interpreter Lock
    (GIL) that allows only one thread to run at any point in time. This will
    preclude efficient parallelization and distribution of the execution of the
    program.

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.py_function()`. If you are using distributed
    TensorFlow, you must run a `tf.distribute.Server` in the same process as the
    program that calls `tf.py_function()` and you must pin the created
    operation to a device in that server (e.g. using `with tf.device():`).

  * Currently `tf.py_function` is not compatible with XLA. Calling
    `tf.py_function` inside `tf.function(jit_compile=True)` will raise an
    error.

  Args:
    func: A Python function that accepts `inp` as arguments, and returns a
      value (or list of values) whose type is described by `Tout`.

    inp: Input arguments for `func`.  A list whose elements are `Tensor`s or
      `CompositeTensors` (such as `tf.RaggedTensor`); or a single `Tensor` or
      `CompositeTensor`.

    Tout: The type(s) of the value(s) returned by `func`.  One of the
      following.

      * If `func` returns a `Tensor` (or a value that can be converted to a
        Tensor): the `tf.DType` for that value.
      * If `func` returns a `CompositeTensor`: The `tf.TypeSpec` for that value.
      * If `func` returns `None`: the empty list (`[]`).
      * If `func` returns a list of `Tensor` and `CompositeTensor` values:
        a corresponding list of `tf.DType`s and `tf.TypeSpec`s for each value.

    name: A name for the operation (optional).

  Returns:
    The value(s) computed by `func`: a `Tensor`, `CompositeTensor`, or list of
    `Tensor` and `CompositeTensor`; or an empty list if `func` returns `None`.
  """
  if ops.executing_eagerly_outside_functions():
    # Place the op in the host's address space, where the Python callback runs.
    with ops.device(context.context().host_address_space()):
      return _internal_py_func(
          func=func, inp=inp, Tout=Tout, use_eager_py_func=True, name=name)

  return _internal_py_func(
      func=func, inp=inp, Tout=Tout, use_eager_py_func=True, name=name)


def py_func_common(func, inp, Tout, stateful=True, name=None):
  """Wraps a python function and uses it as a TensorFlow op.

  Given a python function `func`, which takes numpy arrays as its
  arguments and returns numpy arrays as its outputs, wrap this function as an
  operation in a TensorFlow graph. The following snippet constructs a simple
  TensorFlow graph that invokes the `np.sinh()` NumPy function as an operation
  in the graph:

  ```python
  def my_func(x):
    # x will be a numpy array with the contents of the placeholder below
    return np.sinh(x)
  input = tf.compat.v1.placeholder(tf.float32)
  y = tf.compat.v1.py_func(my_func, [input], tf.float32)
  ```

  **N.B.** The `tf.compat.v1.py_func()` operation has the following known
  limitations:

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.compat.v1.py_func()`. If you are using distributed
    TensorFlow, you must run a `tf.distribute.Server` in the same process as
    the program that calls `tf.compat.v1.py_func()` and you must pin the
    created operation to a device in that server (e.g. using
    `with tf.device():`).

  Note: It produces tensors of unknown shape and rank as shape inference
    does not work on arbitrary Python code.
    If you need the shape, you need to set it based on statically
    available information.

    E.g.
    ```python
    import tensorflow as tf
    import numpy as np

    def make_synthetic_data(i):
        return np.asarray(i).astype(np.uint8) * np.ones([20,256,256,3],
                dtype=np.float32) / 10.

    def preprocess_fn(i):
        ones = tf.py_function(make_synthetic_data,[i],tf.float32)
        ones.set_shape(tf.TensorShape([None, None, None, None]))
        ones = tf.image.resize(ones, [224,224])
        return ones

    ds = tf.data.Dataset.range(10)
    ds = ds.map(preprocess_fn)
    ```

  Args:
    func: A Python function, which accepts `ndarray` objects as arguments and
      returns a list of `ndarray` objects (or a single `ndarray`). This function
      must accept as many arguments as there are tensors in `inp`, and these
      argument types will match the corresponding `tf.Tensor` objects in `inp`.
      The returned `ndarray`s must match the number and types defined by
      `Tout`.
      Important Note: Input and output numpy `ndarray`s of `func` are not
        guaranteed to be copies. In some cases their underlying memory will be
        shared with the corresponding TensorFlow tensors. In-place modification
        or storing `func` input or return values in python datastructures
        without explicit (np.)copy can have non-deterministic consequences.
    inp: A list of `Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns.
    stateful: (Boolean.) If True, the function should be considered stateful. If
      a function is stateless, when given the same input it will return the same
      output and have no observable side effects. Optimizations such as common
      subexpression elimination are only performed on stateless operations.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` or a single `Tensor` which `func` computes.

  @compatibility(TF2)

  This name was deprecated and removed in TF2, but `tf.numpy_function` is a
  near-exact replacement, just drop the `stateful` argument (all
  `tf.numpy_function` calls are considered stateful). It is compatible with
  eager execution and `tf.function`.

  `tf.py_function` is a close but not an exact replacement, passing TensorFlow
  tensors to the wrapped function instead of NumPy arrays, which provides
  gradients and can take advantage of accelerators.

  Before:

  >>> def fn_using_numpy(x):
  ...   x[0] = 0.
  ...   return x
  >>> tf.compat.v1.py_func(fn_using_numpy, inp=[tf.constant([1., 2.])],
  ...     Tout=tf.float32, stateful=False)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 2.], dtype=float32)>

  After:

  >>> tf.numpy_function(fn_using_numpy, inp=[tf.constant([1., 2.])],
  ...     Tout=tf.float32)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 2.], dtype=float32)>

  @end_compatibility

  """
  if context.executing_eagerly():
    # Eager mode: call `func` directly on NumPy copies of the inputs.
    result = func(*[np.array(x) for x in inp])
    result = nest.flatten(result)

    result = [x if x is None else ops.convert_to_tensor(x) for x in result]
    if len(result) == 1:
      # Mimic the automatic unwrapping in graph-mode py_func
      result, = result
    return result

  if ops.executing_eagerly_outside_functions():
    with ops.device(context.context().host_address_space()):
      return _internal_py_func(
          func=func,
          inp=inp,
          Tout=Tout,
          stateful=stateful,
          use_eager_py_func=False,
          name=name)

  return _internal_py_func(
      func=func,
      inp=inp,
      Tout=Tout,
      stateful=stateful,
      use_eager_py_func=False,
      name=name)


@deprecation.deprecated(
    date=None,
    instructions="""tf.py_func is deprecated in TF V2. Instead, there are two
    options available in V2.
    - tf.py_function takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.
    - tf.numpy_function maintains the semantics of the deprecated tf.py_func
    (it is not differentiable, and manipulates numpy arrays). It drops the
    stateful argument making all functions stateful.
    """)
@tf_export(v1=["py_func"])
@dispatch.add_dispatch_support
def py_func(func, inp, Tout, stateful=True, name=None):
  return py_func_common(func, inp, Tout, stateful, name=name)


# The exported v1 symbol shares py_func_common's documentation.
py_func.__doc__ = "%s" % py_func_common.__doc__


@tf_export("numpy_function")
@dispatch.add_dispatch_support
def numpy_function(func, inp, Tout, stateful=True, name=None):
  """Wraps a python function and uses it as a TensorFlow op.

  Given a python function `func` wrap this function as an operation in a
  TensorFlow function. `func` must take numpy arrays as its arguments and
  return numpy arrays as its outputs.

  The following example creates a TensorFlow graph with `np.sinh()` as an
  operation in the graph:

  >>> def my_numpy_func(x):
  ...   # x will be a numpy array with the contents of the input to the
  ...   # tf.function
  ...   return np.sinh(x)
  >>> @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
  ... def tf_function(input):
  ...   y = tf.numpy_function(my_numpy_func, [input], tf.float32)
  ...   return y * y
  >>> tf_function(tf.constant(1.))
  <tf.Tensor: shape=(), dtype=float32, numpy=1.3810978>

  Comparison to `tf.py_function`:
  `tf.py_function` and `tf.numpy_function` are very similar, except that
  `tf.numpy_function` takes numpy arrays, and not `tf.Tensor`s. If you want the
  function to contain `tf.Tensors`, and have any TensorFlow operations executed
  in the function be differentiable, please use `tf.py_function`.

  Note: We recommend avoiding `tf.numpy_function` outside of
  prototyping and experimentation due to the following known limitations:

  * Calling `tf.numpy_function` will acquire the Python Global Interpreter Lock
    (GIL) that allows only one thread to run at any point in time. This will
    preclude efficient parallelization and distribution of the execution of the
    program. Therefore, we discourage using `tf.numpy_function` outside
    of prototyping and experimentation.

  * The body of the function (i.e. `func`) will not be serialized in a
    `tf.SavedModel`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.numpy_function()`. If you are using distributed
    TensorFlow, you must run a `tf.distribute.Server` in the same process as the
    program that calls `tf.numpy_function` and you must pin the created
    operation to a device in that server (e.g. using `with tf.device():`).

  * Currently `tf.numpy_function` is not compatible with XLA. Calling
    `tf.numpy_function` inside `tf.function(jit_compile=True)` will raise an
    error.

  * Since the function takes numpy arrays, you cannot take gradients
    through a numpy_function. If you require something that is differentiable,
    please consider using tf.py_function.

  Args:
    func: A Python function, which accepts `numpy.ndarray` objects as arguments
      and returns a list of `numpy.ndarray` objects (or a single
      `numpy.ndarray`). This function must accept as many arguments as there are
      tensors in `inp`, and these argument types will match the corresponding
      `tf.Tensor` objects in `inp`. The returned `numpy.ndarray`s must match the
      number and types defined by `Tout`.
      Important Note: Input and output `numpy.ndarray`s of `func` are not
        guaranteed to be copies. In some cases their underlying memory will be
        shared with the corresponding TensorFlow tensors. In-place modification
        or storing `func` input or return values in python datastructures
        without explicit (np.)copy can have non-deterministic consequences.
    inp: A list of `tf.Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns.
    stateful: (Boolean.) Setting this argument to False tells the runtime to
      treat the function as stateless, which enables certain optimizations.
      A function is stateless when given the same input it will return the
      same output and have no side effects; its only purpose is to have a
      return value.
      The behavior for a stateful function with the `stateful` argument False
      is undefined. In particular, caution should be taken when
      mutating the input arguments as this is a stateful operation.
    name: (Optional) A name for the operation.

  Returns:
    Single or list of `tf.Tensor` which `func` computes.
  """
  return py_func_common(func, inp, Tout, stateful=stateful, name=name)


def _as_dtype_or_type_spec(t):
  """Returns `t` unchanged if it is a `TypeSpec`, else `dtypes.as_dtype(t)`."""
  return t if isinstance(t, type_spec.TypeSpec) else dtypes.as_dtype(t)


def _wrap_for_composites(func, inp, Tout):
  """Wraps user inputs to support composite tensors for `py_function`.

  1. Flattens `inp` to a list of Tensors (by flattening any composite tensors).
  2. Creates a wrapper function for `func` that expects flat inputs and:
     - Packs the inputs into the input structure expected by `func`.
     - Calls `func` with the packed inputs.
     - Checks that `func`'s output matches `Tout`.
     - Flattens `func`'s output to a list of Tensors (flattening any composite
       tensors).

  Args:
    func: The function to wrap (`func` argument to `py_function`).
    inp: The input arguments for func (`inp` argument to `py_function`).
    Tout: The expected output types for func (`Tout` argument to
      `py_function`).

  Returns:
    A tuple `(func, inp, Tout, out_structure)`, where `func` is the wrapped
    function, `inp` is the flattened inputs, `Tout` is the list of expected
    dtypes for the flattened outputs, and `out_structure` is the expected
    output structure (which can be used to pack the output tensors).

  Raises:
    ValueError: (from the wrapped function, at call time) if an output does
      not match the corresponding entry of `Tout`.
  """
  # Non-composite inputs are represented by the leaf placeholder `1`, so
  # `flatten_up_to` only expands the composite tensors.
  in_structure = [
      v if isinstance(v, composite_tensor.CompositeTensor) else 1 for v in inp
  ]
  inp = nest.flatten_up_to(in_structure, inp, expand_composites=True)
  out_structure = Tout
  Tout = [
      v.dtype if isinstance(v, tensor_spec.TensorSpec) else v
      for v in nest.flatten(Tout, expand_composites=True)
  ]

  def wrapped_func(*flat_inp):
    structured_inp = nest.pack_sequence_as(
        in_structure, flat_inp, expand_composites=True)
    out = func(*structured_inp)
    if not out_structure:
      return []  # Ignore return value if none is requested/expected.
    if not isinstance(out, (list, tuple)):
      out = [out]  # func may return a single value instead of a list.
    flat_out = []
    # NOTE(review): zip() stops at the shorter of `out` and `out_structure`,
    # so a count mismatch is not detected here.
    for elt, expected_type in zip(out, out_structure):
      if (isinstance(expected_type, type_spec.TypeSpec) and
          not isinstance(expected_type, tensor_spec.TensorSpec)):
        if not expected_type.is_compatible_with(elt):
          raise ValueError(
              f"py_function: func={func} returned {out!r}, "
              f"which did not match Tout={out_structure!r}.\nIn particular, "
              f"{elt!r} is not compatible with {expected_type!r}.")
        flat_out.extend(nest.flatten(elt, expand_composites=True))
      else:
        # Pro-actively check if the return value is a composite tensor when
        # we expect a Tensor.  We would catch this later (when we call
        # convert_to_tensor), but checking it here lets us give a better
        # error message.
        if isinstance(elt, composite_tensor.CompositeTensor):
          raise ValueError(
              f"py_function: func={func} returned {out!r}, "
              f"which did not match Tout={out_structure!r}.\nIn particular, "
              f"{elt!r} is not a Tensor.")
        flat_out.append(elt)
    return flat_out

  return wrapped_func, inp, Tout, out_structure


ops.NotDifferentiable("PyFunc")
ops.NotDifferentiable("PyFuncStateless")