# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unified callbacks for op execution and creation under eager and graph modes."""

from tensorflow.python.eager import context
from tensorflow.python.eager import execute


def add_op_callback(callback_fn):
  r"""Add a thread-local callback that intercepts op execution and op creation.

  The `callback_fn` will be invoked immediately after any of the following
  three types of events:
    - The execution of a TensorFlow operation ("op" for short hereafter)
      under eager mode,
    - The execution of a FuncGraph under eager mode,
    - The creation of an op during graph construction (e.g., in
      @tf.function-decorated Python functions).

  Known limitations:
    1. Under graph mode, overriding the output tensors of control-flow ops,
       including "If", "StatelessIf" and "While", may cause errors
       (b/139668453). Overriding other tensors in a graph consisting of such
       control-flow ops is okay.
    2. Under eager mode, calling eager ops from the callback function itself
       may lead to recursion stack overflow. This can be prevented by
       returning from the callback function immediately on encountering the
       op type involved (b/140334369).

  Args:
    callback_fn: A callable with the following signature:
      def callback_fn(op_type,
                      inputs,
                      attrs,
                      outputs,
                      op_name=None,
                      graph=None):
        # op_type: The type of the op, as a string. E.g., "MatMul".
        #          For the special case of FuncGraph execution, op_type
        #          takes the name of the graph, e.g.,
        #          "__inference_my_func_24".
        # inputs: (`tuple` of `Tensor`s) Input tensors to the op or the
        #         FuncGraph.
        #         - In eager execution, these are `EagerTensor`s.
        #         - In graph construction, these are non-eager `Tensor`s
        #           that form the inputs to the just-created op.
        # attrs: The attributes of the op or FuncGraph of which the execution
        #        or creation caused the current invocation of the callback.
        #        This is applicable to both eager- and graph-based execution,
        #        as well as graph construction.
        #        This is a tuple of alternating attribute keys and attribute
        #        values. E.g., `('adjoint_a', False, 'adjoint_b', False)`.
        # outputs: (`tuple` of `Tensor`s) Output tensors from the op or
        #          FuncGraph.
        #          In eager execution, these are `EagerTensor`s.
        #          In graph construction, these are non-eager `Tensor`s that
        #          are the outputs of the just-created op.
        # op_name: Name of the op.
        #          - If the current invocation of the callback is due to the
        #            eager execution of an op or FuncGraph, this will be
        #            `None`, as op names are meaningless in eager execution.
        #          - In graph construction, this is the name of the op, e.g.,
        #            "MatMul_2".
        # graph: The graph that the op belongs to (if any).
        #        - In eager execution of an op or FuncGraph, this is `None`.
        #        - In graph construction, this is the op's enclosing graph
        #          as a `tf.Graph` object.
        #
        # Return values:
        #   This callback function is expected to return `None` or
        #   a `list` or `tuple` of `Tensor`s with its length matching
        #   `len(outputs)`, in the order that corresponds to that of the
        #   `outputs` argument.
        #   If the return value is `None`, downstream execution or graph
        #   construction will be unaffected.
        #   However, if the return value is a `list` or `tuple` of `Tensor`s,
        #   - In eager execution, these returned `Tensor`s should be
        #     `EagerTensor`s. Their values will replace the original values of
        #     `outputs` for downstream eager execution. (*Not implemented yet*).
        #   - In graph construction, these returned `Tensor`s should be
        #     non-eager `Tensor`s. Their values will replace the original
        #     `outputs` for downstream graph construction.

  Raises:
    ValueError: If `callback_fn` is `None` or not callable.
  """
  # TODO(b/139668041): Implement support for overriding `EagerTensor`s from
  # callback.
  if callback_fn is None:
    raise ValueError("Passed callback function cannot be None.")
  if not callable(callback_fn):
    raise ValueError(
        "Callback function passed to op_callback() is expected to be callable, "
        f"but got {callback_fn} of type {type(callback_fn)}.")
  ctx = context.context()
  ctx.add_op_callback(callback_fn)
  if ctx.executing_eagerly():
    # Monkey-patch `execute.execute()` so that callbacks are invoked on every
    # eager op execution.
    execute.execute = execute.execute_with_callbacks
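

# For illustration only: a minimal sketch of a callback satisfying the
# signature documented above. It logs each op event and returns `None`, which
# leaves downstream execution and graph construction unaffected. The name
# `_example_logging_callback` is hypothetical and not part of this module's
# API. Usage: `add_op_callback(_example_logging_callback)`.
def _example_logging_callback(op_type,
                              inputs,
                              attrs,
                              outputs,
                              op_name=None,
                              graph=None):
  """Example op callback: log each event without overriding outputs."""
  del attrs, graph  # Unused in this sketch.
  print("op_type=%s; num_inputs=%d; num_outputs=%d; op_name=%s" %
        (op_type, len(inputs), len(outputs), op_name))
  return None  # `None` means: do not override the op's outputs.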


def should_invoke_op_callbacks():
  """Determine if op callbacks are present and should be invoked.

  Returns:
    A thread-local result (boolean) indicating whether any op callback(s)
    exist and should be invoked.
  """
  ctx = context.context()
  return ctx.op_callbacks and not ctx.invoking_op_callbacks


def remove_op_callback(op_callback):
  """Remove an already-added op callback.

  Args:
    op_callback: The op callback to be removed.

  Raises:
    KeyError: If `op_callback` has not been registered using
      `add_op_callback()` before.
  """
  ctx = context.context()
  ctx.remove_op_callback(op_callback)
  if ctx.executing_eagerly() and not ctx.op_callbacks:
    # Undo the monkey-patching of `execute.execute()` if there are no more
    # callbacks.
    execute.execute = execute.quick_execute


def clear_op_callbacks():
  """Clear all op callbacks registered in the current thread."""
  # Iterate over a copy, because `remove_op_callback()` mutates the
  # underlying collection.
  for callback in list(context.context().op_callbacks):
    remove_op_callback(callback)
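

# For illustration only: a hypothetical helper (not part of this module's
# API) showing the registration lifecycle of the functions above. It
# registers a callback around a workload and removes it afterwards, even if
# the workload raises, so that `execute.execute` is restored to
# `execute.quick_execute` once no callbacks remain.
def _example_callback_scope(callback_fn, workload_fn):
  """Run `workload_fn()` with `callback_fn` registered as an op callback."""
  add_op_callback(callback_fn)
  try:
    return workload_fn()
  finally:
    remove_op_callback(callback_fn)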


def invoke_op_callbacks(op_type,
                        inputs,
                        attrs,
                        outputs,
                        op_name=None,
                        graph=None):
  r"""Invoke the callbacks that exist in the current scope (if any).

  If no callbacks are present in the current scope, this method returns
  immediately.

  Args:
    op_type: Type of the operation (e.g., "MatMul").
    inputs: Input tensors to the op. These are `EagerTensor`s in the case of
      eager execution of ops or `FuncGraph`s, and are non-eager `Tensor`s in
      the case of graph construction.
    attrs: Attributes of the op, as a `tuple` of alternating keys and values.
    outputs: Output tensors from the op. These are `EagerTensor`s in the case
      of eager execution and are non-eager `Tensor`s in the case of graph
      construction.
    op_name: Name of the op. Applicable if and only if this method is invoked
      due to the graph construction of an op or the eager execution of a
      `FuncGraph`.
    graph: The graph involved (if any).
      - In the case of the eager execution of an op or FuncGraph, this is
        `None`.
      - In the case of the graph construction of an op, this is the
        `tf.Graph` object being built.

  Returns:
    `None`, or a `list` or `tuple` of output tensors that will override the
    original (input) `outputs`.
  """
  ctx = context.context()
  if ctx.op_callbacks:
    # Guards against stack overflow that can result from recursive invocation
    # due to op constructions inside client-supplied op callbacks.
    ctx.invoking_op_callbacks = True
    try:
      if isinstance(attrs, dict):
        attrs_list = []
        for key in attrs:
          attrs_list.append(key)
          attrs_list.append(attrs[key])
        attrs_tuple = tuple(attrs_list)
      else:
        attrs_tuple = attrs

      new_outputs = outputs
      for callback in ctx.op_callbacks:
        callback_outputs = callback(
            op_type,
            inputs,
            attrs_tuple,
            new_outputs,
            op_name=op_name,
            graph=graph)
        if callback_outputs is not None:
          if len(callback_outputs) != len(outputs):
            raise ValueError(
                f"The op callback returned {len(callback_outputs)} tensors, "
                f"which does not match the original number of outputs of op "
                f"{op_name} ({len(outputs)}).")
          # Feed the overridden outputs to the next callback in the chain, so
          # that a later callback returning `None` does not discard an
          # earlier callback's override.
          new_outputs = callback_outputs
      return new_outputs
    finally:
      ctx.invoking_op_callbacks = False
  else:
    return outputs
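

# For illustration only: a minimal sketch of a callback that overrides op
# outputs during graph construction (overriding `EagerTensor`s is not
# implemented yet; see the TODO in `add_op_callback`). Each symbolic output
# is replaced by an equivalent `Identity` tensor; the recursion guard in
# `invoke_op_callbacks` prevents the `Identity` ops created here from
# re-triggering the callbacks. The name `_example_identity_override` is
# hypothetical.
def _example_identity_override(op_type,
                               inputs,
                               attrs,
                               outputs,
                               op_name=None,
                               graph=None):
  """Example op callback: replace graph outputs with `Identity` copies."""
  del op_type, inputs, attrs, op_name  # Unused in this sketch.
  if graph is None:
    return None  # Eager event: leave the outputs unchanged.
  # Imported here (rather than at module level) to avoid a circular import in
  # this sketch.
  from tensorflow.python.ops import array_ops
  # The returned list must have the same length as `outputs`.
  return [array_ops.identity(output) for output in outputs]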