# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unified callbacks for op execution and creation under eager and graph modes."""

from tensorflow.python.eager import context
from tensorflow.python.eager import execute


def add_op_callback(callback_fn):
22  r"""Add a thread-local callback that intercepts op execution and op creation.
23
24  The `callback_fn` will be invoked immediately after any of the three types
25  of events:
26    - The execution of an TensorFlow operation ("op" for short hereafter)
27      under eager mode,
28    - The execution of a FuncGraph under eager mode,
29    - The creation of an op during graph construction (e.g., in
30      @tf.function-decorated Python functions).
31
32  Known limitations:
33    1. Under graph mode, overriding the output tensors of control-flow ops,
34       including "If", "StatelessIf" and "While", may cause errors
35       (b/139668453). Overriding other tensors in a graph consisting of such
36       control-flow ops is okay.
37    2. Under eager mode, calling eager ops from the callback function itself
38       may lead to recursion stack overflow. This can be prevented by
39       returning from the callback function immediately on encountering the
40       op type involved (b/140334369).
41
42  Args:
43    callback_fn: A callback_fn that has the following signature:
      def callback_fn(op_type,
                      inputs,
                      attrs,
                      outputs,
                      op_name=None,
                      graph=None):
        # op_type: The type of the op, as a string. E.g., "MatMul".
        #          For the special case of FuncGraph execution, op_type
        #          takes the name of the graph, e.g.,
        #          "__inference_my_func_24".
        # inputs: (`tuple` of `Tensor`s) Input tensors to the op or the
        #         FuncGraph.
        #         - In eager execution, these are `EagerTensor`s.
        #         - In graph construction, these are non-eager `Tensor`s
        #           that form the inputs to the just-created op.
        # attrs: The attributes of the op or FuncGraph of which the execution
        #        or creation caused the current invocation of the callback.
        #        This is applicable to both eager- and graph-based execution,
        #        as well as graph construction.
        #        This is a tuple of alternating attribute keys and attribute
        #        values. E.g., `('adjoint_a', False, 'adjoint_b', False)`.
        # outputs: (`tuple` of `Tensor`s) Output tensors from the op or
        #          FuncGraph.
        #          In eager execution, these are `EagerTensor`s.
        #          In graph construction, these are non-eager `Tensor`s that
        #          are the outputs of the just-created op.
        # op_name: Name of the op.
        #          - If the current invocation of the callback is due to the
        #            eager execution of an op or FuncGraph, this will be
        #            `None`, as op names are meaningless in eager execution.
        #          - In graph construction, this is the name of the op, e.g.,
        #            "MatMul_2".
        # graph: The graph that the op belongs to (if any).
        #        - In eager execution of an op or FuncGraph, this is `None`.
        #        - In graph construction, this is the op's enclosing graph
        #          as a `tf.Graph` object.
        #
        # Return values:
        #   This callback function is expected to return `None` or
        #   a `list` or `tuple` of `Tensor`s with its length matching
        #   `len(outputs)`, in the order that corresponds to that of the
        #   `outputs` argument.
        #   If the return value is `None`, downstream execution or graph
        #   construction will be unaffected.
        #   However, if the return value is a `list` or `tuple` of `Tensor`s,
        #   - In eager execution, these returned `Tensor`s should be
        #     `EagerTensor`s. Their values will replace the original values of
        #     `outputs` for downstream eager execution. (*Not implemented yet*).
        #   - In graph construction, these returned `Tensor`s should be
        #     non-eager `Tensor`s. Their values will replace the original
        #     `outputs` for downstream graph construction.
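
  Example:
    A minimal sketch of a callback that logs op types and leaves all outputs
    unmodified (illustrative only; `log_op_callback` is a hypothetical
    user-defined function, not part of this module):

    def log_op_callback(op_type, inputs, attrs, outputs,
                        op_name=None, graph=None):
      # Returning `None` leaves downstream execution or graph construction
      # unaffected.
      print("op_type: %s, op_name: %s" % (op_type, op_name))
      return None

    add_op_callback(log_op_callback)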

  Raises:
    ValueError: If `callback_fn` is `None` or not callable.
  """
  # TODO(b/139668041): Implement support for overriding `EagerTensor`s from
  # callback.
  if callback_fn is None:
    raise ValueError("Passed callback function cannot be None.")
  if not callable(callback_fn):
    raise ValueError(
        "Callback function passed to add_op_callback() is expected to be "
        f"callable, but got {callback_fn} of type {type(callback_fn)}.")
  ctx = context.context()
  ctx.add_op_callback(callback_fn)
  if ctx.executing_eagerly():
    # Monkey-patch `execute.execute()` so that eager op execution invokes the
    # registered callbacks.
    execute.execute = execute.execute_with_callbacks


def should_invoke_op_callbacks():
  """Determine if op callbacks are present and should be invoked.

  Returns:
    A thread-local result (boolean) indicating whether any op callback(s) exist
    and should be invoked.
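
  Example:
    A sketch of the guard pattern this predicate supports (`op_type`,
    `inputs`, `attrs`, and `outputs` are hypothetical local variables):

    if should_invoke_op_callbacks():
      overridden = invoke_op_callbacks(op_type, inputs, attrs, outputs)
      if overridden is not None:
        outputs = overridden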
120  """
121  ctx = context.context()
122  return ctx.op_callbacks and not ctx.invoking_op_callbacks
123
124
125def remove_op_callback(op_callback):
126  """Remove an already-added op callback.
127
128  Args:
129    op_callback: The op callback to be removed.
130
131  Raises:
132    KeyError: If `op_callback` has not been registered using `add_op_callback()`
133      before.
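
  Example:
    A sketch that pairs registration with guaranteed removal (`my_callback`
    is a hypothetical callback function):

    add_op_callback(my_callback)
    try:
      ...  # Run ops or build graphs with the callback active.
    finally:
      remove_op_callback(my_callback)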
134  """
135  ctx = context.context()
136  ctx.remove_op_callback(op_callback)
137  if ctx.executing_eagerly() and not ctx.op_callbacks:
138    # Undo monkey-patch of execute.execute if there are no more callbacks.
139    execute.execute = execute.quick_execute
140
141
142def clear_op_callbacks():
143  """Clear all op callbacks registered in the current thread."""
144  for callback in context.context().op_callbacks:
145    remove_op_callback(callback)
146
147
148def invoke_op_callbacks(op_type,
149                        inputs,
150                        attrs,
151                        outputs,
152                        op_name=None,
153                        graph=None):
154  r"""Invoke the callbacks that exist in the current scope (if any).
155
156  If no callbacks are present in the current scope, this method returns
157  immediately.
158
159  Args:
160    op_type: Type of the operation (e.g., "MatMul").
161    inputs: Input tensors to the op. These are `EagerTensor`s in the case of
162      eager execution of ops or `FuncGraph`s, and are non-eager `Tensor`s in the
163      case of graph construction.
    attrs: Attributes of the op, as a `tuple` of alternating keys and values.
      A `dict` mapping attribute names to values is also accepted and will be
      flattened into such a tuple.
    outputs: Output tensors from the op. These are `EagerTensor`s in the case of
      eager execution and are non-eager `Tensor`s in the case of graph
      construction.
    op_name: Name of the op. Applicable if and only if this method is invoked
      due to the graph construction of an op or the eager execution of a
      `FuncGraph`.
    graph: The graph involved (if any).
      - In the case of the eager execution of an op or FuncGraph, this is
        `None`.
      - In the case of the graph construction of an op, this is the `tf.Graph`
        object being built.

  Returns:
    `None`, or a `list` or `tuple` of output tensors that will override the
    original (input) `outputs`.
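
  Example:
    A sketch of invoking the registered callbacks for a just-created op
    (`op` is a hypothetical `tf.Operation`; `attrs_dict` is a hypothetical
    `dict` of its attribute values):

    overridden = invoke_op_callbacks(
        op.type, tuple(op.inputs), attrs_dict, tuple(op.outputs),
        op_name=op.name, graph=op.graph)
    outputs = tuple(op.outputs) if overridden is None else overridden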
180  """
181  ctx = context.context()
182  if ctx.op_callbacks:
183    # Guards against stack overflow that can result from recursive invocation
184    # due to op constructions inside client-supplied op callbacks.
185    ctx.invoking_op_callbacks = True
186    try:
187      if isinstance(attrs, dict):
188        attrs_list = []
189        for key in attrs:
190          attrs_list.append(key)
191          attrs_list.append(attrs[key])
192        attrs_tuple = tuple(attrs_list)
193      else:
194        attrs_tuple = attrs
195
196      new_outputs = outputs
197      for callback in ctx.op_callbacks:
198        new_outputs = callback(
199            op_type,
200            inputs,
201            attrs_tuple,
202            new_outputs,
203            op_name=op_name,
204            graph=graph)
205        if new_outputs is not None and len(new_outputs) != len(outputs):
206          raise ValueError(
207              f"The op callback returned {len(new_outputs)} tensors, which "
208              f"does not match the original number of outputs of op {op_name} "
209              f"({len(outputs)}).")
210      return new_outputs
211    finally:
212      ctx.invoking_op_callbacks = False
213  else:
214    return outputs
215