# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for V2 control flow."""

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import gradients_util
from tensorflow.python.util import keras_deps
from tensorflow.python.util import tf_contextlib


_EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None
_DISABLE_LOWER_USING_SWITCH_MERGE = False


CondBranchFuncGraph = control_flow_v2_func_graphs.CondBranchFuncGraph
WhileCondFuncGraph = control_flow_v2_func_graphs.WhileCondFuncGraph
WhileBodyFuncGraph = control_flow_v2_func_graphs.WhileBodyFuncGraph


def in_defun():
  """Returns whether the current graph is, or is nested in, a defun."""
  if context.executing_eagerly(): return False

  graph = ops.get_default_graph()
  while (isinstance(graph, CondBranchFuncGraph) or
         isinstance(graph, WhileBodyFuncGraph) or
         isinstance(graph, WhileCondFuncGraph)):
    graph = graph.outer_graph
  return isinstance(graph, FuncGraph)


def in_while_loop_defun(graph):
  """Returns whether the graph is a while loop FuncGraph."""
  if context.executing_eagerly(): return False
  return (isinstance(graph, WhileCondFuncGraph) or
          isinstance(graph, WhileBodyFuncGraph))


def create_new_tf_function(func_graph):
  """Converts func_graph to a TF_Function and adds it to the current graph.

  Args:
    func_graph: FuncGraph

  Returns:
    The name of the new TF_Function.
  """
  func = function._EagerDefinedFunction(  # pylint: disable=protected-access
      func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {})
  func.add_to_graph(func_graph.outer_graph)
  return func_graph.name


def unique_fn_name(scope, name):
  """Returns a unique name to use for a control flow function.

  Args:
    scope: A name scope string.
    name: An identifier for this function (e.g. "true", "body").

  Returns:
    A string, the name to use for the function.
  """
  return ("%s%s_%s" % (scope, name, ops.uid())).replace("/", "_")


def unique_grad_fn_name(forward_name):
  return "%s_grad_%s" % (forward_name, ops.uid())
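

# Illustrative sketch (comment only, not executed): the two helpers above
# produce graph-unique, slash-free function names. The uid values below are
# made up for the example; `ops.uid()` returns a fresh integer on each call.
#
#   unique_fn_name("outer/cond/", "true")      # -> "outer_cond_true_12"
#   unique_grad_fn_name("outer_cond_true_12")  # -> "outer_cond_true_12_grad_13"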


def maybe_set_lowering_attr(op, lower_using_switch_merge=None):
  """Sets the flag to enable lowering on `op` if necessary.

  Lowering allows cond_v2 and while_v2 to avoid some of the limitations of
  Functions, allowing users to specify devices & colocation inside of cond_v2
  and while_v2 input functions, and enabling non-strict evaluation & partial
  pruning. This brings v2 control flow closer to feature parity with v1
  control flow.

  However, we do not lower in the following cases:
    - When the `If` or `While` ops are in the XLA context, because it is
      easier for XLA to apply its own optimizations when dealing with
      un-lowered control flow operators than with low-level control flow
      primitives.
    - When the eager execution context specifies the executor of functions to
      be the single threaded executor (see context.function_executor_type()),
      because the single threaded executor does not support v1 control flow
      ops.
    - When `lower_using_switch_merge` is explicitly set to False.

  Args:
    op: An `If` or `While` Operation.
    lower_using_switch_merge: Explicit value to lower or not (optional).
  """
  if lower_using_switch_merge is not None:
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge",
                 attr_value_pb2.AttrValue(b=lower_using_switch_merge))
    # pylint: enable=protected-access
  elif (not _DISABLE_LOWER_USING_SWITCH_MERGE and
        not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
        context.context().function_call_options.executor_type !=
        "SINGLE_THREADED_EXECUTOR"):
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access


def maybe_propagate_compile_time_consts_in_xla(op):
  """Tells XLA whether to propagate compile-time consts in the loop body.

  This is needed to make compile-time constants, e.g. `max_num_elements` in
  `EmptyTensorList`, available to ops inside the loop body. Ideally this would
  always be turned on, but that doesn't work with legacy functionalized
  while_loops.

  Args:
    op: A `While` Operation.
  """
  if control_flow_util.GraphOrParentsInXlaContext(op.graph):
    # pylint: disable=protected-access
    op._set_attr("_xla_propagate_compile_time_consts",
                 attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access
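

# Illustrative sketch (comment only): when `maybe_set_lowering_attr(op)` does
# set the attribute, the decision can be read back from the node, e.g.
#
#   op.get_attr("_lower_using_switch_merge")  # -> True or False
#
# A later graph-rewriting pass then replaces the functional `If`/`While` node
# with the Switch/Merge-based v1 primitives; nodes without the attribute (or
# with it set to False) keep executing as function calls.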


def resource_input_index(tensor_name, input_names, node_defs, functions):
  """Returns the index of the input corresponding to `tensor_name`.

  This method is used to find the corresponding index of an arbitrary resource
  tensor in a function (the function could be a loop body). We assume that
  resource handles are never created in functions, so that every resource
  tensor can be traced back to a function input.

  The awkward signature of this method is to make it work with both FuncGraphs
  and FunctionDefs. This is so we can recurse on function call ops without
  building the corresponding FuncGraph (note that even if a FuncGraph for a
  FunctionDef already exists, the input/output/node names may have been
  changed when the FuncGraph was serialized to the FunctionDef, which makes it
  unusable with this algorithm).

  Args:
    tensor_name: the name of the resource tensor to be resolved to an input.
    input_names: a list of the names of all inputs to the function.
    node_defs: a dict mapping op name -> NodeDef for every op in the function.
    functions: a dict mapping function name -> _EagerDefinedFunction.

  Returns:
    The index into input_names corresponding to `tensor_name`.
  """
  while tensor_name not in input_names:
    # FunctionDefs and graphs use different tensor naming conventions.
    parts = tensor_name.split(":")
    if len(parts) == 3:
      op_name, _, output_idx = parts
    elif len(parts) == 2:
      op_name, output_idx = parts
    else:
      assert len(parts) == 1
      op_name = parts[0]
      output_idx = 0
      tensor_name = "%s:%d" % (tensor_name, output_idx)
      # Check again for cases where the tensor suffix (":0") is stripped out.
      if tensor_name in input_names:
        break
    output_idx = int(output_idx)
    node_def = node_defs[op_name]

    def _extract_input_index(function_attribute_name):
      func_name = node_def.attr[function_attribute_name].func.name
      fdef = functions[func_name].definition
      output_arg_name = fdef.signature.output_arg[output_idx].name
      output_tensor_name = fdef.ret[output_arg_name]
      return resource_input_index(
          output_tensor_name, [arg.name for arg in fdef.signature.input_arg],
          {ndef.name: ndef for ndef in fdef.node_def}, functions)

    if node_def.op in ("Identity", "While"):
      # Captured resources occur at the same index in the lists of inputs and
      # outputs of a while or identity op. So we look up the input of the op
      # at the same index as the index of the resource in the op's outputs.
      tensor_name = node_def.input[output_idx]
    elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
      # Functions output any captured resource tensors used by their
      # gradients. `tensor_name` is one of these outputs from a nested
      # function call, so recursively find the corresponding input in the
      # nested FunctionDef.
      tensor_name = node_def.input[_extract_input_index("f")]
    elif node_def.op in ("If", "StatelessIf"):
      input_index = _extract_input_index("then_branch")
      if input_index != _extract_input_index("else_branch"):
        raise AssertionError(
            ("Expected cond branches ({} op) to each have the same "
             "input->output mapping of resources.").format(node_def.op))
      tensor_name = node_def.input[
          # Ignore the `cond` input; the function inputs come after.
          input_index + 1]
    else:
      # We assume there are no other op types that will "forward" resource
      # handles like this, so all other handles must have been created by the
      # op. (Note that cond_v2 wraps resource handle outputs in optionals,
      # which we'll end up accumulating).
      raise ValueError("Taking gradient of a while loop which creates "
                       "a resource in its body is not supported: %s (%s)"
                       % (op_name, node_def.op))

  return input_names.index(tensor_name)
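

# Illustrative sketch (comment only) of the tensor-name forms normalized by
# the loop in `resource_input_index`; the names are made up for the example:
#
#   "while/Identity"        graph name with the ":0" suffix stripped
#   "while/Identity:0"      graph convention: op_name:output_index
#   "body_fn:identity:0"    FunctionDef convention: op_name:output_arg:index
#
# Each form is reduced to an (op_name, output_idx) pair and traced backwards
# through Identity/While/If/PartitionedCall nodes until a function input is
# reached.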


@tf_contextlib.contextmanager
def clear_control_inputs():
  """Clears the control inputs but preserves the ControlFlowContext.

  This is needed to preserve the XLAControlFlowContext when clearing control
  inputs for the gradient accumulators in while_v2.
  `ops.control_dependencies` does not allow that.

  Yields:
    A context manager in which the ops created will not have any control inputs
    by default but the control flow context is the same.
  """
  # pylint: disable=protected-access
  control_flow_context = ops.get_default_graph()._get_control_flow_context()
  with ops.control_dependencies(None):
    ops.get_default_graph()._set_control_flow_context(control_flow_context)
    yield
  # pylint: enable=protected-access


def _is_tpu_strategy(strategy):
  return (strategy is not None and
          strategy.__class__.__name__.startswith("TPUStrategy"))


def _is_building_keras_layer():
  # TODO(srbs): Remove this function when we no longer support sessions with
  # Keras.
  keras_call_context_function = keras_deps.get_call_context_function()
  if keras_call_context_function:
    return keras_call_context_function().layer is not None
  else:
    return False


def output_all_intermediates():
  """Whether to output all intermediates of a functional control flow op.

  The default behavior is to output intermediates only when building a Keras
  Layer in graph mode, and then only when certain other conditions are met:
  1. We do not output intermediates if the functional control flow op is
     being built inside a FuncGraph that is not an If/While graph. This
     guards against outputting intermediates in eager mode, since Keras adds
     tensors to a FuncGraph named "keras_graph" in that case. Also, since we
     do not output intermediates of tf.function (this feature is only for
     backwards compatibility), outputting intermediates of functional control
     flow ops built inside tf.function has no value.
  2. We do not output intermediates when the compilation is using XLA or for a
     TPU.
  3. We do not output intermediates when a single threaded executor is used
     since that does not perform inlining and pruning.

  Returns:
    A bool telling whether to output all intermediates.
  """
  if _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE is not None:
    return _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
  if in_defun():
    return False
  if (control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()) or
      _is_tpu_strategy(distribution_strategy_context.get_strategy())):
    return False
  if (context.context().function_call_options.executor_type ==
      "SINGLE_THREADED_EXECUTOR"):
    return False
  return _is_building_keras_layer()
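

# Illustrative sketch (comment only, assuming this module is imported as
# `control_flow_util_v2`): the module-level override above short-circuits the
# heuristic in `output_all_intermediates`, e.g.
#
#   control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
#
# forces intermediates to be output for every functional control flow op,
# False disables it everywhere, and None (the default) restores the heuristic.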


def get_func_graph(op, input_shapes, func_name):
  """Generates and returns a FuncGraph for the given op and input_shapes."""
  fdef = None
  graph = op.graph
  # Search for the function in this graph and any enclosing graphs.
  while graph is not None:
    func = graph._get_function(func_name)  # pylint: disable=protected-access
    if func is not None:
      fdef = func.definition
      break
    if hasattr(graph, "outer_graph"):
      graph = graph.outer_graph
    else:
      break

  if fdef is None:
    raise KeyError("%s cannot be found in the graph" % func_name)

  # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
  # in the case of nested if ops or when the gradient is being computed
  # from inside a Defun. We build the `func_graph` with `op.graph` as its
  # `outer_graph`. This resembles how the `FuncGraph` was built in the
  # forward pass. We need this so that we can resolve references to tensors
  # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
  with op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(
        fdef, input_shapes)
  return func_graph


def get_op_and_outputs(op_or_outputs):
  if isinstance(op_or_outputs, ops.Operation):
    return op_or_outputs, []
  elif not op_or_outputs:  # Empty list.
    return None, []
  else:
    return op_or_outputs[0].op, op_or_outputs


def graph_wrapped_for_higher_order_tape_gradients(graph):
  """Check if `graph` is wrapped by `run_as_function_for_tape_gradients`."""
  while graph is not None:
    if "cflow_gradient_wrapper" in getattr(graph, "name", ""):
      return True
    graph = getattr(graph, "outer_graph", None)
  return False


def run_as_function_for_tape_gradients(make_op, inputs):
  """Fix higher-order tape gradients by wrapping `make_op` in a function.

  Args:
    make_op: A function that takes a list of inputs and returns a list of
      output tensors. This function should set any handle data relevant to its
      outputs before returning.
    inputs: A list of tensors to check for tape gradients and pass to
      `make_op`. These should include all tensors used in `make_op`.

  Returns:
    Tensors corresponding to `make_op`'s output.
  """
  # GradientTapes created inside a function currently don't work well with
  # un-wrapped control flow ops in that same function. Wrapping in an extra
  # layer of intermediate function means we run extra logic in the function
  # gradient code to record the correct intermediates on the tape.
  #
  # The function attribute inputs to control flow ops are not hashable, so we
  # pass everything as a capture to bypass defun's caching.
  if (gradients_util.PossibleTapeGradientTypes(inputs)
      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
      # We only need one function between the tape and the op; if we've
      # already wrapped once, we stop wrapping to avoid infinite recursion.
      and not (ops.get_default_graph().building_function
               and "cflow_gradient_wrapper" in ops.get_default_graph().name)):
    results = function.defun_with_attributes(
        make_op,
        autograph=False,
        attributes=dict(func_name="cflow_gradient_wrapper"))(inputs)
    return results
  else:
    return make_op(inputs)
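

# Illustrative caller sketch (comment only; `_build_op` and `all_inputs` are
# made-up names, mirroring how cond_v2/while_v2 are expected to use the helper
# above): the op-building closure and all of its tensor inputs are passed
# together so that a higher-order GradientTape records a single wrapped
# function call instead of a bare functional control flow op.
#
#   outputs = run_as_function_for_tape_gradients(
#       lambda inputs: _build_op(inputs), all_inputs)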