# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Python front-end support for functions.

NOTE: At this time, functions are experimental and subject to change! Proceed
with caution.
"""

import collections
import hashlib

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect


# TODO(b/136040013): Drop support for Defun.
class Defun(object):
  """Obsolete. Slated for deletion. Please use tf.function instead.

  Known feature gaps while migrating to tf.function (could be outdated):
  - tf.function doesn't support Send/Recv capability since it doesn't share
    rendezvous with the main graph but always creates a new one.
  - tf.function doesn't support custom gradient functions directly; instead
    you need to define the function inside a tf.custom_gradient wrapper
    together with the gradient function (see the second example below for
    Defun's own python_grad_func mechanism).
  - Unlike Defun, Keras layers used inside a tf.function need to be created
    only once to avoid variable recreation.
  - Defun respects device assignments and applies them to the function body,
    but tf.function requires this to be done manually.
  - Defun might prune out unused ops automatically, but tf.function doesn't.

  Limitations of Defun:
  - Original source locations are not preserved, so errors do not include
    full/valid stack traces.
  - Only supports a linear sequence of arguments and return values, putting the
    burden on the caller to pack/unpack everything across a Defun boundary into
    tuples (as opposed to passing list- and dict-like structures directly).
  - Does not support overloading or late-bound specializations.
  - Has its own way of defining gradient overrides which does not follow
    current conventions.
  - Cannot support imperative control flow or automatic control dependencies.
  - Does not reflect statefulness in the graph and has a calling convention
    that differs from how more modern tools interact.
  - Is only compatible with graph building mode.

  Decorator used to define TensorFlow functions.

  Use this decorator to make a Python function usable directly as a TensorFlow
  function.

  The decorated function must add ops to the default graph and return zero or
  more `Tensor` objects.  Call the decorator with named arguments, one for each
  argument of the function to decorate, with the expected type of the argument
  as value.

  For example, if the function to decorate accepts two `tf.float32` arguments
  named `x` and `y`, call the decorator with:

      @Defun(tf.float32, tf.float32)
      def foo(x, y):
        ...

  When you call the decorated function, it adds the `call` ops to the
  default graph. In addition, it adds the definition of the function into the
  default graph. Because the addition of the function into the graph
  is deferred, the decorator can be used anywhere in the program.

  Any variables created inside of the function are hoisted into the outer
  graph. Note that the variables are created in the variable scope that was
  active during the first call to the function. Subsequent function calls will
  refer to the same set of variables.

  Definitions of functions in a graph are frozen as soon as the graph is used
  to create a session. However, new functions and new calls to existing
  functions may be added to the graph, with the new functions themselves
  becoming immediately frozen.

  Example (see also the [How To on functions](link_needed)):

  ```python
  # Defining the function.
  @tf.Defun(tf.float32, tf.float32)
  def MyFunc(x, y):
    return x + y, x - y

  # Building the graph.
  a = tf.constant([1.0])
  b = tf.constant([2.0])
  c, d = MyFunc(a, b, name='mycall')
  ```
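
  A gradient can also be supplied Python-side via `python_grad_func`. A
  minimal sketch (`_MyAddGrad` and `MyAdd` are hypothetical names):

  ```python
  # Receives the op and one gradient per output; returns one gradient
  # per input. For x + y, both input gradients equal the output gradient.
  def _MyAddGrad(op, dz):
    return dz, dz

  @tf.Defun(tf.float32, tf.float32, python_grad_func=_MyAddGrad)
  def MyAdd(x, y):
    return x + y
  ```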
  """

  def __init__(self, *input_types, **kwargs):
    """Create a `Defun` decorator.

    Args:
      *input_types: A list of `tf.DType`.
      **kwargs: Optional keyword arguments, including
         func_name - (optional).  A python string, the name to use to
           declare this `Function` in the graph.

         grad_func - (optional).  A function implementing the gradient
           of the function-to-register.  This must be a
           `_DefinedFunction` object. The gradient
           function must satisfy the criterion defined in
           function.proto:GradientDef.

         python_grad_func - (optional).  A function implementing the
           gradient of the function python-side. This function must
           take the current op and the gradients w.r.t. its outputs,
           and return the gradients w.r.t. the inputs. That is, it must
           implement the interface expected by `tf.RegisterGradient`.
           This will be called by tf.gradients to add the gradient ops
           to the graph. At most one of grad_func and python_grad_func
           can be specified.

         out_names - (optional). A list of strings, one per output
           tensor.

         shape_func - (optional). A function taking the op and returning a list
           of static shapes to set for the function's outputs.
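
           For example, a hypothetical `shape_func` that declares every
           output to have the same static shape as the op's first input
           might look like this (illustrative sketch):

               def _first_input_shapes(op):
                 return [op.inputs[0].get_shape() for _ in op.outputs]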
    """
    self._input_types = input_types
    self._func_name = kwargs.pop("func_name", None)
    self._grad_func = kwargs.pop("grad_func", None)
    self._python_grad_func = kwargs.pop("python_grad_func", None)
    self._out_names = kwargs.pop("out_names", None)
    self._extra_kwargs = kwargs

  def __call__(self, func):
    # Various sanity checks on the callable func.
    if not callable(func):
      raise ValueError(f"Function {func} must be a callable.")

    # Func should not use kwargs and defaults.
    argspec = tf_inspect.getargspec(func)
    if argspec.keywords or argspec.defaults:
      raise ValueError(
          "Functions with argument defaults or keyword arguments are not "
          f"supported. {func} has defaults {argspec.defaults} and keywords "
          f"{argspec.keywords}.")

    # Computes how many arguments 'func' has.
    min_args = len(argspec.args)
    max_args = min_args
    if argspec.varargs:
      max_args = 1000000
    argnames = argspec.args
    if tf_inspect.ismethod(func):
      # 1st argument is the "class" type.
      min_args -= 1
      argnames = argnames[1:]

    if self._input_types:
      # If Defun is given a list of types for the inputs, the number
      # of input types should be compatible with 'func'.
      num = len(self._input_types)
      if num < min_args or num > max_args:
        raise ValueError(
            "The number of tf.function input types is not compatible with the "
            f"allowed arguments of {func}. The tf.function has {num} input "
            f"types, while the Python function allows a minimum of {min_args} "
            f"and a maximum of {max_args} arguments.")
      return _DefinedFunction(
          func,
          argnames,
          self._input_types,
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # 'func' expects no arguments and the input type list is empty.
    if min_args == 0 and max_args == 0:
      return _DefinedFunction(
          func, [], [],
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # Input types are unknown. It's an overloaded function and hence
    # its definition needs to be deferred until it's called.
    return _OverloadedFunction(
        func,
        argnames,
        self._func_name,
        self._grad_func,
        self._python_grad_func,
        out_names=self._out_names,
        **self._extra_kwargs)


class _DefinedFunctionDeleter(object):
  """Unregister function from eager context."""

  __slots__ = ["name"]

  def __init__(self, name):
    self.name = name

  def __del__(self):
    try:
      context.remove_function(self.name)
    except TypeError:
      # Suppress some exceptions, mainly for the case when we're running on
      # module deletion. Things that can go wrong include the context module
      # already being unloaded, self._handle._handle_data no longer being
      # valid, and so on. Printing warnings in these cases is silly
      # (exceptions raised from __del__ are printed as warnings to stderr).
      pass  # 'NoneType' object is not callable when the handle has been
      # partially unloaded.
    except AttributeError:
      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
      # been unloaded. Will catch other module unloads as well.


class _DefinedFunction(object):
  """_DefinedFunction encapsulates a function definition and its properties.

  Attributes:
    name: The function name.
    definition: The definition of this function. A FunctionDef proto.
    grad_func_name: If not None, the name of this function's gradient function.
    python_grad_func: A python callable implementing the gradient of
      the function python-side.
  """

  def __init__(self,
               func,
               argnames,
               input_types,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               shape_func=None,
               capture_by_value=False,
               allowlisted_stateful_ops=None,
               capture_resource_var_by_value=True,
               **kwargs):
    """Creates _DefinedFunction.

    Args:
      func:  A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      input_types: The function's argument types. Can be a tuple or list of
        tf data types.
      func_name: The function name. Defaults to None, in which case the name
        is derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: An optional list of strings for the function return value
        names.
      shape_func: An optional function mapping an op to a list of static
        output shapes.
      capture_by_value: Boolean (defaults to False). If True, captured values
        will be copied into the function body.
      allowlisted_stateful_ops: A set of ops whose statefulness is ignored and
        which are copied into the function body when `capture_by_value` is
        True.
      capture_resource_var_by_value: Boolean (defaults to True). If False,
        captured resource variables are passed by handle instead of by value.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.

    """
    self._func = func
    self._input_types = input_types
    self._func_name = func_name
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._shape_func = shape_func
    self._capture_by_value = capture_by_value
    self._allowlisted_stateful_ops = allowlisted_stateful_ops
    if self._allowlisted_stateful_ops is None:
      self._allowlisted_stateful_ops = set()
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._extra_kwargs = kwargs
    # Constructed only when C API is disabled, lazily.
    self._definition = None
    # Constructed only when C API is enabled, lazily.
    self._c_func = None
    self._function_deleter = None
    self._sub_functions = {}  # Constructed with _definition or _c_func.
    # pylint: disable=protected-access
    device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
    # pylint: enable=protected-access

    # Get the innermost device if possible.
    self._caller_device = device_funcs[-1] if device_funcs else None

    # Cached OpDef for this function. When C API is enabled, this is
    # the only part of FunctionDef that we cache in Python. When C API
    # is disabled the whole _definition is available and this is simply
    # another reference to _definition.signature.
    self._op_def = None

    assert isinstance(input_types, (list, tuple))
    self._arg_types = input_types
    self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i)
                       for i in range(len(input_types))]

  @property
  def name(self):
    """Function name."""
    self._create_definition_if_needed()
    return self._func_name

  @property
  def definition(self):
    """Function definition proto."""
    self._create_definition_if_needed()
    if self._c_func:
      with c_api_util.tf_buffer() as buf:
        with self._c_func.get() as func:
          c_api.TF_FunctionToFunctionDef(func, buf)
          fdef = function_pb2.FunctionDef()
          proto_data = c_api.TF_GetBuffer(buf)
          fdef.ParseFromString(compat.as_bytes(proto_data))
          with ops.init_scope():
            if context.executing_eagerly():
              context.add_function(func)
              self._function_deleter = _DefinedFunctionDeleter(
                  fdef.signature.name)
      return fdef
    return self._definition

  @property
  def _signature(self):
    self._create_definition_if_needed()
    return self._op_def

  def set_grad_func(self, grad_func):
    """Specifies the gradient function of this function."""
    assert not self._grad_func
    assert isinstance(grad_func, _DefinedFunction)
    self._grad_func = grad_func

  @property
  def grad_func_name(self):
    """Returns the name of the gradient function."""
    return self._grad_func.name if self._grad_func else None

  @property
  def python_grad_func(self):
    """Python gradient function callable."""
    return self._python_grad_func

  @property
  def declared_input_types(self):
    """Returns the list of data types of explicitly declared inputs."""
    return self._input_types

  @property
  def captured_inputs(self):
    """Returns the list of implicitly captured inputs."""
    self._create_definition_if_needed()
    return self._extra_inputs

  @property
  def stateful_ops(self):
    """Returns the list of stateful ops in the function definition.

    Returns:
      A list of (op.name, op.type) pairs.
    """
    self._create_definition_if_needed()
    return self._stateful_ops

  def _create_definition_if_needed(self):
    """Creates the function definition if it's not created yet."""
    with context.graph_mode():
      self._create_definition_if_needed_impl()

  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed."""
    if self._definition is not None or self._c_func is not None:
      return

    # Copy variable collections (by reference) from the parent graph such that
    # name based variable sharing (e.g. via tf.make_template) works between the
    # func graph and parent graph.
    variable_keys = []
    variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
    variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access

    parent_graph = ops.get_default_graph()
    collections_ref = {
        key: parent_graph.get_collection_ref(key) for key in variable_keys}

    temp_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device,
        collections_ref=collections_ref,
        allowlisted_stateful_ops=self._allowlisted_stateful_ops,
        capture_resource_var_by_value=self._capture_resource_var_by_value)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = function_utils.get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    # FIXME(feyu): C API is always enabled now. The if-true branch never runs.
    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef.
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      with temp_graph._c_graph.get() as c_graph:
        c_func = c_api.TF_GraphToFunction_wrapper(
            c_graph,
            base_func_name,
            self._func_name is None,  # append_hash_to_fn_name
            None,  # opers
            [t._as_tf_output() for t in temp_graph.inputs],
            [t._as_tf_output() for t in temp_graph.outputs],
            output_names,
            [],  # control_outputs
            [],  # control_output_names
            None,  # opts
            description)
      self._c_func = c_api_util.ScopedTFFunction(c_func, base_func_name)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set).
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op._is_stateful]  # pylint: disable=protected-access

  def _set_c_attrs(self, attrs):
    """Sets `attrs` as attributes of self._c_func.

    Requires that self._c_func is not None.

    Args:
      attrs: a dictionary from attribute name to attribute proto value.
    """
    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use the same status.
      with self._c_func.get() as func:
        c_api.TF_FunctionSetAttrValueProto(func, compat.as_str(name),
                                           serialized)

  def _create_hash_str(self, input_arg, output_arg, node_def):
    """Creates an 8-character string unique to this input.

    Args:
      input_arg: the input_arg field of an OpDef
                 (e.g. self._definition.signature.input_arg)
      output_arg: the output_arg field of an OpDef
                 (e.g. self._definition.signature.output_arg)
      node_def: the node_def field of a FunctionDef
                (e.g. self._definition.node_def)

    Returns:
      The unique string for this input.
    """
    hasher = hashlib.sha1()

    def update_num(n):
      hasher.update(compat.as_bytes("%x" % n))

    def update_str(s):
      update_num(len(s))
      hasher.update(compat.as_bytes(s))

    def update_strs(slist):
      update_num(len(slist))
      for s in slist:
        update_str(s)

    for adef in input_arg:
      update_str(adef.SerializeToString())

    for adef in output_arg:
      update_str(adef.SerializeToString())

    for n in sorted(node_def, key=lambda n: n.name):
      update_str(n.name)
      update_str(n.op)
      update_strs(n.input)
      update_num(len(n.attr))
      # NOTE: protobuf map serialization does not guarantee ordering.
      for k in sorted(n.attr):
        update_str(k)
        update_str(n.attr[k].SerializeToString())

    return hasher.hexdigest()[:8]

  def add_to_graph(self, g):
    """Adds this function into the graph g."""
    self._create_definition_if_needed()

    # Adds this function into 'g'.
    # pylint: disable=protected-access
    if context.executing_eagerly():
      context.context().add_function_def(self.definition)
    else:
      g._add_function(self)
    # pylint: enable=protected-access

    # Ensures related sub-routines are defined in 'g', too.
    for f in self._sub_functions.values():
      f.add_to_graph(g)

    # Adds its gradient function, too.
    if self._grad_func:
      self._grad_func.add_to_graph(g)

  def __call__(self, *args, **kwargs):
    self.add_to_graph(ops.get_default_graph())
    args = [ops.convert_to_tensor(x) for x in args] + self._extra_inputs
    ret, op = _call(self._signature, *args, **kwargs)

    # Set a hidden attr in 'op' so that gradients_impl can refer back
    # to this _DefinedFunction instance to access python_grad_func.
    assert isinstance(op, ops.Operation)
    setattr(op, "__defun", self)

    if self._shape_func is not None:
      shapes = self._shape_func(op)
      if len(shapes) != len(op.outputs):
        raise ValueError(f"shape_func {self._shape_func} produced "
                         f"{len(shapes):d} shapes, which does not match "
                         f"{len(op.outputs)} outputs.")
      for (t, shape) in zip(op.outputs, shapes):
        t.set_shape(shape)
    return ret


class _OverloadedFunction(object):
  """_OverloadedFunction encapsulates an overloaded function.

  _OverloadedFunction maintains a mapping from input types to
  an instantiated _DefinedFunction in self._overload.

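  For instance, a `Defun` with no declared input types becomes overloaded and
  is specialized per call-site dtype. A minimal sketch:

      @Defun()
      def Square(x):
        return x * x

      Square(tf.constant(1.0))     # instantiates a float32 specialization
      Square(tf.constant([2, 3]))  # instantiates an int32 specialization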
  """

  def __init__(self,
               func,
               argnames,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               **kwargs):
    """Creates an _OverloadedFunction.

    Args:
      func:  A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      func_name: The function name. Defaults to None, in which case the name
        is derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: A list of strings for the function return value names.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.

    """
    self._func = func
    self._argnames = argnames
    self._func_name = func_name
    assert grad_func is None or isinstance(grad_func, _OverloadedFunction)
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._extra_kwargs = kwargs
    self._overload = {}

  def instantiate(self, input_types):
    """Instantiates this function given input argument types.

    Args:
      input_types: A list of data types for the inputs.

    Returns:
      _DefinedFunction for the given input types.

    """
    # Stringify the type list.
    key = _type_list_to_str(input_types)
    defined = self._overload.get(key)
    if not defined:
      # If not defined yet, define the function given the input types.
      name = self._func_name
      if name is not None:
        name = "_".join([name, key])
      defined = _DefinedFunction(
          self._func,
          self._argnames,
          input_types,
          name,
          None,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)
      _ = defined.name  # Fully instantiate the function definition.
      if self._grad_func:
        # If _grad_func is given, it is another
        # _OverloadedFunction. We need to instantiate it with the
        # right input types.
        output_types = [
            dtypes.DType(arg.type) for arg in defined._signature.output_arg  # pylint: disable=protected-access
        ]
        # pylint: disable=protected-access
        defined._grad_func = self._grad_func.instantiate(input_types +
                                                         output_types)
        # pylint: enable=protected-access
      self._overload[key] = defined
    return defined

  def __call__(self, *args, **kwargs):
    input_types = []
    args = list(args)
    for (i, x) in enumerate(args):
      x = ops.convert_to_tensor(x)
      if not isinstance(x, ops.Tensor):
        raise ValueError(f"Expected a Tensor but got {x} with type {type(x)}.")
      input_types.append(x.dtype)
      args[i] = x
    return self.instantiate(input_types)(*args, **kwargs)


class _FuncGraph(ops.Graph):
  """A helper for constructing a function.

  _FuncGraph overrides ops.Graph's create_op() so that we can keep
  track of all inputs into every op created inside the function.  If
  any input is from another graph, we keep track of it in self._captured
  and substitute the input with a placeholder.

  Each captured input's corresponding placeholder is converted into a
  function argument, and the caller passes in the captured tensor.
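
  For example (a minimal sketch), a tensor created outside the function body
  is captured automatically and rewritten to a placeholder argument:

      x = tf.constant(1.0)

      @Defun(tf.float32)
      def AddX(y):
        return x + y   # 'x' becomes a captured placeholder input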
  """

  def __init__(self, name, capture_by_value, allowlisted_stateful_ops,
               capture_resource_var_by_value, *args, **kwargs):
    super(_FuncGraph, self).__init__(*args, **kwargs)
    self._capture_by_value = capture_by_value
    self._allowlisted_stateful_ops = allowlisted_stateful_ops
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._building_function = True
    self._outer_graph = ops.get_default_graph()
    self._vscope = vs.get_variable_scope()
    self._old_custom_getter = self._vscope.custom_getter

    # The name of the function.
    self.name = name
    # Placeholder tensors representing the inputs to this function. The tensors
    # are in this _FuncGraph.
    self.inputs = []
    # Tensors that will be returned by this function. The tensors are in this
    # _FuncGraph.
    self.outputs = []
    # Maps external tensor -> internal tensor (e.g. input placeholder).
    self._captured = {}
    # The external tensors that have been captured as inputs and must be passed
    # to this function (empty if capturing by value, otherwise these are the
    # keys of _captured).
    self.extra_inputs = []
    # Input placeholders that have been added for captured values (empty if
    # capturing by value).
    self.extra_args = []
    # Captured variables.
    # TODO(skyewm): is this needed?
    self.extra_vars = []

  # pylint: disable=g-doc-return-or-yield

  @property
  def outer_graph(self):
    """The graph active when this _FuncGraph was created."""
    return self._outer_graph

  @tf_contextlib.contextmanager
  def container(self, container_name):
    """Returns a context manager that specifies the resource container to use.

    Overridden from `tf.Graph` to update both the init_scope container
    and the present inner container. This is necessary to make sure setting
    containers applies correctly both to created variables and to stateful
    ops.

    Args:
      container_name: container name string.

    Returns:
      A context manager for defining resource containers for stateful ops,
        yields the container name.
    """
    original_container = self._container
    # pylint: disable=protected-access
    with ops.init_scope():
      original_init_container = ops.get_default_graph()._container
    try:
      self._container = container_name
      with ops.init_scope():
        ops.get_default_graph()._container = container_name
      yield self._container
    finally:
      self._container = original_container
      with ops.init_scope():
        ops.get_default_graph()._container = original_init_container
    # pylint: enable=protected-access

  # pylint: enable=g-doc-return-or-yield

  def getvar(
      self,
      getter,
      name,
      shape=None,
      dtype=None,
      initializer=None,
      reuse=None,
      trainable=True,
      collections=None,  # pylint: disable=redefined-outer-name
      use_resource=None,
      **kwargs):
    """A custom variable getter."""
    # Here, we switch the default graph to the outer graph and ask the
    # variable scope in which the function is defined to give us the
    # variable. The variable is stashed in extra_vars and returned to
    # the caller.
    #
    # We capture these variables so that the variable definition is
    # hoisted upward to the outermost graph.
    with self._outer_graph.as_default():
      # pylint: disable=protected-access
      var = self._vscope.get_variable(
          vs._get_default_variable_store(),
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          use_resource=use_resource)
      self.extra_vars.append(var)
      if (isinstance(var, resource_variable_ops.BaseResourceVariable) and
          self._capture_resource_var_by_value):
        # For resource-based variables read the variable outside the function
        # and pass in the value. This ensures that the function is pure and
        # differentiable. TODO(apassos) this may have performance problems if
        # the function will only do embedding lookups on the variable.
        return var.value()
      return var

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    for i, x in enumerate(inputs):
      if isinstance(x, ops.EagerTensor) or x.graph is not self:
        inputs[i] = self.capture(x)
    return super(_FuncGraph, self)._create_op_internal(
        op_type,
        inputs,
        dtypes=dtypes,
        input_types=input_types,
        name=name,
        attrs=attrs,
        op_def=op_def,
        compute_device=compute_device)

  def capture(self, tensor, name=None):
    """Adds the given tensor to this graph and returns the captured tensor."""
    if tensor.ref() in self._captured:
      # Captured already.
      return self._captured[tensor.ref()]
    elif self._capture_by_value:
      return self._add_tensor_and_parents(tensor)
    else:
      return self._capture_tensor_as_extra_input(tensor, name)

  @property
  def captures(self):
    """List of (external tensor, captured tensor) pairs."""
    return [(k.deref(), v) for k, v in self._captured.items()]

  def _capture_tensor_as_extra_input(self, tensor, name=None):
    # Substitute with a placeholder.
    self.extra_inputs.append(tensor)
    # Hoist the new input placeholder out of any control flow context
    # we're currently in.
    with ops.control_dependencies(None):
      ph = array_ops.placeholder(
          tensor.dtype, shape=tensor.get_shape(), name=name)
    # pylint: disable=protected-access
    if isinstance(tensor, ops.EagerTensor):
      handle_data = tensor._handle_data
      if handle_data:
        handle_data = handle_data.SerializeToString()
    else:
      with tensor.graph._c_graph.get() as c_graph:
        handle_data = c_api.GetHandleShapeAndType(c_graph,
                                                  tensor._as_tf_output())

    if handle_data:
      with ph.graph._c_graph.get() as c_graph:
        c_api.SetHandleShapeAndType(c_graph, ph._as_tf_output(),
                                    compat.as_bytes(handle_data))
    # pylint: enable=protected-access
    self.inputs.append(ph)
    self._captured[tensor.ref()] = ph
    self.extra_args.append(ph)
    if _is_guaranteed_const(tensor):
      with ops.control_dependencies(None):
        return array_ops.guarantee_const(ph)
    else:
      return ph

  def _add_tensor_and_parents(self, tensor):
    op = self._add_op_and_parents(tensor.op)
    return op.outputs[tensor.value_index]

  def _add_op_and_parents(self, op):
    # pylint: disable=protected-access
    op_def = graph_to_function_def._get_op_def(op)
    if op._is_stateful and op not in self._allowlisted_stateful_ops:
      raise ValueError(f"Cannot capture a stateful node (name:{op.name}, "
                       f"type:{op.type}) by value.")
    elif op.type in ("Placeholder", "PlaceholderV2"):
      raise ValueError(f"Cannot capture a placeholder (name:{op.name}, "
                       f"type:{op.type}) by value.")
    # pylint: enable=protected-access

    captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]

    captured_op = self._create_op_internal(
        op.type,
        captured_inputs, [o.dtype for o in op.outputs],
        name=op.name,
        attrs=op.node_def.attr,
        op_def=op_def)

    for t, captured_t in zip(op.outputs, captured_op.outputs):
      self._captured[t.ref()] = captured_t

    return captured_op


def func_graph_from_py_func(func,
                            arg_names,
                            arg_types,
                            name=None,
                            capture_by_value=False,
                            device=None,
                            colocation_stack=None,
                            container=None,
                            collections_ref=None,
                            arg_shapes=None,
                            allowlisted_stateful_ops=None,
                            capture_resource_var_by_value=True):
  """Returns a _FuncGraph generated from `func`.

  Args:
    func: A Python callable which constructs a TF function body. The arguments
      must correspond to `arg_types`. Returns a value or list/tuple of values.
      No returned value can be None.
    arg_names: A sequence of strings for the function argument names.
    arg_types: A sequence of the function's argument types.
    name: The function name. If None, the name is derived from `func`.
    capture_by_value: boolean. If True, captured values will be copied into the
      function body.
    device: device name or function.
    colocation_stack: A colocation stack (list) the _FuncGraph should use.
    container: A container name the _FuncGraph should start with.
    collections_ref: A reference to a collections dict the _FuncGraph should
      use internally.
    arg_shapes: A sequence of the function's argument shapes.
    allowlisted_stateful_ops: A set of ops whose statefulness is ignored and
      which are re-created when capturing by value.
    capture_resource_var_by_value: Boolean (defaults to True). If False,
      captured resource variables are passed by handle instead of by value.

  Returns:
    A _FuncGraph.

  Raises:
    ValueError: if func returns None.
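
  For example (a minimal sketch):

      g = func_graph_from_py_func(lambda x: x * 2.0, ["x"], [dtypes.float32])
      # g.inputs holds the argument placeholders; g.outputs the results.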
  """
  if not name:
    name = function_utils.get_func_name(func)
  func_graph = _FuncGraph(name, capture_by_value, allowlisted_stateful_ops,
                          capture_resource_var_by_value)

  with func_graph.as_default(), ops.device(device):
    # pylint: disable=protected-access
    if collections_ref is not None:
      func_graph._collections = collections_ref
    if container is not None:
      func_graph._container = container
    if colocation_stack is not None:
      func_graph._colocation_stack = colocation_stack
    # pylint: enable=protected-access

    if arg_shapes is None:
      arg_shapes = [None] * len(arg_types)

    # Create placeholders for the function arguments.
    for (argname, argtype, argshape) in zip(arg_names, arg_types, arg_shapes):
      argholder = array_ops.placeholder(argtype, shape=argshape, name=argname)
      func_graph.inputs.append(argholder)
    # Call func and gather the output tensors.
    with vs.variable_scope("", custom_getter=func_graph.getvar):
      outputs = func(*func_graph.inputs)

    # There is no way of distinguishing between a function not returning
    # anything and a function returning None in Python.
    # We need to allow the former and ideally want to forbid the latter as
    # it is most likely user error.
    # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
    # allow users to explicitly mark the function as not returning anything.
    # For now, we allow a single None return and interpret it as a function
    # with no output.
    if outputs is None:
      outputs = []
    else:
      # If func only returned one value, make it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
      if any(o is None for o in outputs):
        raise ValueError(f"Function {name} cannot return None.")
    # Ensures each output is a Tensor in the function graph.
    outputs = [ops.convert_to_tensor(t) for t in outputs]
    outputs = [func_graph.capture(t) if t.graph is not func_graph else t
               for t in outputs]
    func_graph.outputs = outputs
  return func_graph


def _is_guaranteed_const(tensor):
  """Determines whether `tensor` is guaranteed to be a constant.

  A tensor is guaranteed to be a constant if either it was produced by
  a `GuaranteeConst` op or if all of its inputs are guaranteed to be
  constants.

  Args:
    tensor: The tensor for which to determine const-ness.

  Returns:
    True if `tensor` is guaranteed to be a constant, False otherwise.
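
  For example (a minimal sketch, using this module's imports):

      a = array_ops.guarantee_const(array_ops.placeholder(dtypes.float32))
      b = array_ops.guarantee_const(array_ops.placeholder(dtypes.float32))
      _is_guaranteed_const(a + b)  # True: all inputs are guaranteed const.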
  """

  if isinstance(tensor, ops.EagerTensor):
    return False

  class Work(object):

    def __init__(self, op, leaving):
      self.op = op
      self.leaving = leaving

  is_guaranteed_const = lambda op: op.node_def.op == "GuaranteeConst"
  constants = set()

  def all_inputs_const(op):
    # If all inputs of an op are guaranteed constants, then we can infer that
    # the op produces a constant as well.
    return op.inputs and all(inp.op in constants for inp in op.inputs)

  visited = set()
  stack = [Work(tensor.op, leaving=False)]
  while stack:
    work = stack.pop()
    if work.leaving:
      if all_inputs_const(work.op):
        constants.add(work.op)
      continue
    visited.add(work.op)
    if is_guaranteed_const(work.op):
      constants.add(work.op)
      continue

    # This op will be revisited after all its inputs are checked for
    # const-ness.
    stack.append(Work(work.op, leaving=True))
    for inp in work.op.inputs:
      if inp.op not in visited:
        stack.append(Work(inp.op, leaving=False))
  return tensor.op in constants


def _call(sig, *inputs, **kwargs):
  """Adds a node calling a function.

  This adds a `call` op to the default graph that calls the function
  of signature `sig`, passing the tensors in `inputs` as arguments.
  It returns the outputs of the call, which are one or more tensors.

  `sig` is the `OpDef` signature of the function, e.g. the `_signature`
  of a `_DefinedFunction` object.

  You can pass an optional keyword parameter `name=string` to name the
  added operation.

  You can pass an optional keyword parameter `noinline=True|False` to
  instruct the runtime not to inline the function body into the call
  site.
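
  For example (a minimal sketch, where `f` is a `_DefinedFunction` and `x`,
  `y` are Tensors matching its signature):

      ret, op = _call(f._signature, x, y, name="mycall", noinline=True)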

  Args:
    sig: OpDef. The signature of the function.
    *inputs: arguments to the function.
    **kwargs: Optional keyword arguments.  Can only contain 'name' or
        'noinline'.

  Returns:
     A 2-element tuple. First element: a Tensor if the function returns a
     single value; a tuple of Tensors if the function returns multiple values;
     the Operation if the function returns no values. Second element: the
     Operation.

  Raises:
    ValueError: if the arguments are invalid.
  """
  if len(inputs) != len(sig.input_arg):
    raise ValueError(f"Expected {len(sig.input_arg):d} arguments, got "
                     f"{len(inputs):d}.")
  name = kwargs.pop("name", None)
  g = ops.get_default_graph()
  func_name = sig.name
  if name is None:
    name = func_name
  attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
  output_types = [dtypes.DType(x.type) for x in sig.output_arg]
  op = g._create_op_internal(  # pylint: disable=protected-access
      func_name, list(inputs), output_types, name=name, attrs=attrs, op_def=sig)
  if op.outputs:
    if len(op.outputs) == 1:
      ret = op.outputs[0]
    else:
      ret = tuple(op.outputs)
  else:
    ret = op
  return ret, op


def _from_definition(fdef, grad_func=None):
  """Creates a _DefinedFunction initialized from a FunctionDef proto.

  Args:
    fdef: a FunctionDef
    grad_func: a _DefinedFunction or None

  Returns:
    A _DefinedFunction representing fdef
  """
  # TODO(iga): This method does major surgery on _DefinedFunction.
  # Make it a named constructor using @classmethod of _DefinedFunction.

  # The Python callable is only needed to create a FunctionDef. Since we have
  # the FunctionDef here, we don't need to set _DefinedFunction._func (nor do
  # we have access to such a callable here).
  func = None
  argnames = [arg.name for arg in fdef.signature.input_arg]
  input_types = tuple(
      dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
  func_name = fdef.signature.name
  # Note: FunctionDefs do not include python gradient functions, so if the
  # original _DefinedFunction included one it will not be reflected here.
  python_grad_func = None
  out_names = [arg.name for arg in fdef.signature.output_arg]
  result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
                            python_grad_func, out_names)
  # pylint: disable=protected-access
  serialized = fdef.SerializeToString()
  c_func = c_api.TF_FunctionImportFunctionDef(serialized)
  result._c_func = c_api_util.ScopedTFFunction(c_func, func_name)
  result._extra_inputs = []
  result._op_def = fdef.signature
  # pylint: enable=protected-access

  return result


def from_library(lib):
  """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.

  This method handles assigning the correct gradient functions to each
  function.

  Args:
    lib: a FunctionDefLibrary

  Returns:
    A list of _DefinedFunctions

  Raises:
    ValueError: `lib` is invalid
  """
  if not lib.function and not lib.gradient:
    return []

  # function name -> FunctionDef proto
  funcs = {fdef.signature.name: fdef for fdef in lib.function}

  # Validate that all referenced function names have function defs.
  for g in lib.gradient:
    if g.function_name not in funcs:
      raise ValueError(f"FunctionDefLibrary missing '{g.function_name}' "
                       f"FunctionDef\n{lib}")
    if g.gradient_func not in funcs:
      raise ValueError(f"FunctionDefLibrary missing '{g.gradient_func}' "
                       f"FunctionDef\n{lib}")

  # function name -> gradient function name
  func_to_grad = collections.defaultdict(lambda: None)
  # gradient function name -> names of functions having that grad function
  grad_to_funcs = collections.defaultdict(list)

  for gdef in lib.gradient:
    func_to_grad[gdef.function_name] = gdef.gradient_func
    grad_to_funcs[gdef.gradient_func].append(gdef.function_name)

  # Start with functions without gradients.
  ready = [
      fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
  ]
  if not ready:
    raise ValueError(
        f"FunctionDefLibrary contains cyclic gradient functions!\n{lib}")
  # function name -> _DefinedFunction
  initialized = {}

  while ready:
    fdef = ready.pop()
    name = fdef.signature.name

    grad = initialized.get(func_to_grad[name])
    if func_to_grad[name]:
      assert grad
    defined_func = _from_definition(fdef, grad_func=grad)
    initialized[name] = defined_func

    ready.extend(funcs[f] for f in grad_to_funcs[name])

  return initialized.values()


def _get_experimental_kwarg_as_attr(attr_name, value):
  """Creates an AttrValue for a python object."""
  if isinstance(value, bool):
    return attr_value_pb2.AttrValue(b=value)
  elif isinstance(value, int):
    return attr_value_pb2.AttrValue(i=value)
  elif isinstance(value, float):
    return attr_value_pb2.AttrValue(f=value)
  elif isinstance(value, str):
    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
  else:
    raise ValueError(f"Attribute {attr_name} must be bool, int, float, or "
                     f"str. Got {type(value)}.")


def _get_kwarg_as_str_attr(attr_name, value):
  """Creates an AttrValue for a python object."""
  if isinstance(value, str):
    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
  else:
    raise ValueError(f"Attribute {attr_name} must be str. Got {type(value)}.")


def _parse_kwargs_as_attrs(func_name, **kwargs):
  """Parses **kwargs into a node's attributes."""
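  # For example (illustrative): _parse_kwargs_as_attrs("f", noinline=True)
  # yields {"_noinline": AttrValue(b=True),
  #         "_disable_call_shape_inference": AttrValue(b=True)}.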
  attrs = {}

  noinline = kwargs.pop("noinline", None)
  if noinline is not None:
    attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline))

  # For compatibility with previous behavior, Defun does not perform shape
  # inference through its function call operations.
  attrs["_disable_call_shape_inference"] = attr_value_pb2.AttrValue(b=True)

  compiled = kwargs.pop("compiled", None)
  separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None)
  if compiled is not None:
    attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled))
    attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue(
        b=bool(separate_compiled_gradients))
    # Forward _XlaScope from enclosing context (if set), otherwise create new.
    # pylint: disable=protected-access
    if "_XlaScope" in ops.get_default_graph()._attr_scope_map:
      attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"]
    else:
      attrs["_XlaScope"] = attr_value_pb2.AttrValue(
          s=("function_%s" % func_name).encode())
    # pylint: enable=protected-access

  kwargs_keys = list(kwargs.keys())
  for key in kwargs_keys:
    if key.startswith("experimental_"):
      attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key])
      del kwargs[key]
    # Support for https://github.com/tensorflow/community/pull/113/files.
    elif key == "_implements" or key == "_reference":
      attrs[key] = _get_kwarg_as_str_attr(key, kwargs[key])
      del kwargs[key]
  if kwargs:
    raise ValueError(f"Unknown keyword arguments: {kwargs.keys()}.")
  return attrs


def get_extra_vars():
  """Returns the variables created by the function being defined.

  Returns:
    If the default graph is being used to define a function, the
    returned list contains the variables created inside the function
    body so far. Otherwise, returns an empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_vars
  else:
    return []


def get_extra_inputs():
  """Returns the input tensors captured by the function being defined.

  Returns:
    If the default graph is being used to define a function, the
    returned list contains the tensors accessed inside the function body
    but defined outside of it so far. Otherwise, returns an empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_inputs
  else:
    return []


def get_extra_args():
  """Returns the corresponding function arguments for the captured inputs.

  Returns:
    If the default graph is being used to define a function, the
    returned list contains the placeholders used inside the function
    body that correspond to the tensors returned by get_extra_inputs().
    Otherwise, returns an empty list.
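
  For example (a minimal sketch; `captured` is a hypothetical tensor from the
  outer graph):

      @Defun(tf.float32)
      def Foo(x):
        y = x + captured        # 'captured' becomes an extra input
        phs = get_extra_args()  # placeholders for get_extra_inputs() tensors
        return y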
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_args
  else:
    return []


def _type_list_to_str(types):
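  # For example (illustrative): [dtypes.float32, dtypes.int32] -> "f32i32".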
  if any(t not in _DTYPE_TO_STR for t in types):
    unsupported_types = [t for t in types if t not in _DTYPE_TO_STR]
    raise ValueError(f"Unsupported dtypes {unsupported_types} in "
                     "`types`. Supported dtypes are "
                     f"{list(_DTYPE_TO_STR.keys())}.")
  return "".join(_DTYPE_TO_STR[t] for t in types)


# NOTE: The list needs to be extended when more data types are added.
_DTYPE_TO_STR = {
    dtypes.float16: "f16",
    dtypes.float32: "f32",
    dtypes.float64: "f64",
    dtypes.int32: "i32",
    dtypes.uint8: "u8",
    dtypes.uint16: "u16",
    dtypes.uint32: "u32",
    dtypes.uint64: "u64",
    dtypes.int16: "i16",
    dtypes.int8: "i8",
    dtypes.string: "s",
    dtypes.complex64: "c64",
    dtypes.complex128: "c128",
    dtypes.int64: "i64",
    dtypes.bool: "b",
    dtypes.qint8: "qi8",
    dtypes.quint8: "qu8",
    dtypes.qint16: "qi16",
    dtypes.quint16: "qu16",
    dtypes.qint32: "qi32",
    dtypes.bfloat16: "b16",
}