# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

16"""Utilities for V2 control flow."""
17
18from tensorflow.core.framework import attr_value_pb2
19from tensorflow.python.distribute import distribution_strategy_context
20from tensorflow.python.eager import context
21from tensorflow.python.eager import function
22from tensorflow.python.framework import function_def_to_graph
23from tensorflow.python.framework import ops
24from tensorflow.python.framework.func_graph import FuncGraph
25from tensorflow.python.ops import control_flow_util
26from tensorflow.python.ops import control_flow_v2_func_graphs
27from tensorflow.python.ops import gradients_util
28from tensorflow.python.util import keras_deps
29from tensorflow.python.util import tf_contextlib
30
31
32_EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None
33_DISABLE_LOWER_USING_SWITCH_MERGE = False
34
35
36CondBranchFuncGraph = control_flow_v2_func_graphs.CondBranchFuncGraph
37WhileCondFuncGraph = control_flow_v2_func_graphs.WhileCondFuncGraph
38WhileBodyFuncGraph = control_flow_v2_func_graphs.WhileBodyFuncGraph
39
40
def in_defun():
  """Returns whether the current graph is, or is nested in, a defun."""
  if context.executing_eagerly(): return False

  graph = ops.get_default_graph()
  while (isinstance(graph, CondBranchFuncGraph) or
         isinstance(graph, WhileBodyFuncGraph) or
         isinstance(graph, WhileCondFuncGraph)):
    graph = graph.outer_graph
  return isinstance(graph, FuncGraph)


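# Illustrative example (not part of the original module): `in_defun` walks out
# of any control-flow FuncGraphs before checking whether the enclosing graph
# is itself a FuncGraph. Assuming a typical `tf.function` setting:
#
#   in_defun()              # eager mode -> False
#
#   @tf.function
#   def f():
#     return in_defun()     # inside a tf.function graph -> True
#
# A cond branch graph nested in a tf.function also yields True (the branch
# graph is skipped); a branch graph built directly in a plain v1 Graph, with
# no surrounding function, yields False.
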
def in_while_loop_defun(graph):
  """Returns whether the graph is a while loop FuncGraph."""
  if context.executing_eagerly(): return False
  return (isinstance(graph, WhileCondFuncGraph) or
          isinstance(graph, WhileBodyFuncGraph))


def create_new_tf_function(func_graph):
  """Converts func_graph to a TF_Function and adds it to the current graph.

  Args:
    func_graph: FuncGraph

  Returns:
    The name of the new TF_Function.
  """
  func = function._EagerDefinedFunction(  # pylint: disable=protected-access
      func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {})
  func.add_to_graph(func_graph.outer_graph)
  return func_graph.name


def unique_fn_name(scope, name):
  """Returns a unique name to use for a control flow function.

  Args:
    scope: A name scope string.
    name: An identifier for this function (e.g. "true", "body").

  Returns:
    A string, the name to use for the function.
  """
  return ("%s%s_%s" % (scope, name, ops.uid())).replace("/", "_")


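# Illustrative example (not part of the original module): `unique_fn_name`
# concatenates the scope and name, appends a graph-unique uid, and replaces
# any "/" from the name scope with "_" so the result is a valid function name:
#
#   unique_fn_name("cond/", "true")   # -> e.g. "cond_true_12"
#   unique_fn_name("while/", "body")  # -> e.g. "while_body_13"
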
def unique_grad_fn_name(forward_name):
  """Returns a unique name for the gradient function of `forward_name`."""
  return "%s_grad_%s" % (forward_name, ops.uid())


def maybe_set_lowering_attr(op, lower_using_switch_merge=None):
  """Sets the flag to enable lowering on `op` if necessary.

  Lowering allows cond_v2 and while_v2 to avoid some of the limitations of
  Functions, allowing users to specify devices & colocation inside of cond_v2
  and while_v2 input functions, and enabling non-strict evaluation & partial
  pruning. This brings v2 control flow closer to feature parity with v1 control
  flow.

  However, we do not lower in the following cases:
    - When the `If` or `While` ops are in the XLA context, because it is easier
      for XLA to apply its own optimizations to un-lowered control flow
      operators than to low-level control flow primitives.
    - When the eager execution context specifies the single-threaded executor
      for functions (see `context.function_executor_type()`), because the
      single-threaded executor does not support v1 control flow ops.
    - When `lower_using_switch_merge` is explicitly set to False.

  Args:
    op: An `If` or `While` Operation.
    lower_using_switch_merge: Explicit value to lower or not (optional).
  """
  if lower_using_switch_merge is not None:
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge",
                 attr_value_pb2.AttrValue(b=lower_using_switch_merge))
    # pylint: enable=protected-access
  elif (not _DISABLE_LOWER_USING_SWITCH_MERGE and
        not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
        context.context().function_call_options.executor_type !=
        "SINGLE_THREADED_EXECUTOR"):
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access


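# Illustrative example (not part of the original module), assuming `if_op` is
# an `If` op built outside an XLA context with the default executor:
#
#   maybe_set_lowering_attr(if_op)
#   if_op.get_attr("_lower_using_switch_merge")   # -> True
#
#   maybe_set_lowering_attr(if_op, lower_using_switch_merge=False)
#   if_op.get_attr("_lower_using_switch_merge")   # -> False
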
def maybe_propagate_compile_time_consts_in_xla(op):
  """Tells XLA whether to propagate compile-time consts in the loop body.

  This is needed to make compile time constants available to ops, for example
  `max_num_elements` in `EmptyTensorList`, inside the loop body. Ideally this
  would always be turned on, but that doesn't work with legacy functionalized
  while_loops.

  Args:
    op: A `While` Operation.
  """
  if control_flow_util.GraphOrParentsInXlaContext(op.graph):
    # pylint: disable=protected-access
    op._set_attr("_xla_propagate_compile_time_consts",
                 attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access


def resource_input_index(tensor_name, input_names, node_defs, functions):
  """Returns the index of the input corresponding to `tensor_name`.

  This method is used to find the corresponding index of an arbitrary resource
  tensor in a function (the function could be a loop body). We assume that
  resource handles are never created in functions, so that every resource
  tensor can be traced back to a function input.

  The awkward signature of this method is to make it work with both FuncGraphs
  and FunctionDefs. This is so we can recurse on function call ops without
  building the corresponding FuncGraph (note that even if a FuncGraph for a
  FunctionDef already exists, the input/output/node names may have been
  changed when the FuncGraph was serialized to the FunctionDef, which makes it
  unusable with this algorithm).

  Args:
    tensor_name: the name of the resource tensor to be resolved to an input.
    input_names: a list of the names of all inputs to the function.
    node_defs: a dict mapping op name -> NodeDef for every op in the function.
    functions: a dict mapping function name -> _EagerDefinedFunction.

  Returns:
    The index into input_names corresponding to `tensor_name`.
  """
  while tensor_name not in input_names:
    # FunctionDefs and graphs use different tensor naming conventions.
    parts = tensor_name.split(":")
    if len(parts) == 3:
      op_name, _, output_idx = parts
    elif len(parts) == 2:
      op_name, output_idx = parts
    else:
      assert len(parts) == 1
      op_name = parts[0]
      output_idx = 0
      tensor_name = "%s:%d" % (tensor_name, output_idx)
      # Check again for cases where the tensor suffix (":0") is stripped out.
      if tensor_name in input_names:
        break
    output_idx = int(output_idx)
    node_def = node_defs[op_name]

    def _extract_input_index(function_attribute_name):
      func_name = node_def.attr[function_attribute_name].func.name
      fdef = functions[func_name].definition
      output_arg_name = fdef.signature.output_arg[output_idx].name
      output_tensor_name = fdef.ret[output_arg_name]
      return resource_input_index(
          output_tensor_name, [arg.name for arg in fdef.signature.input_arg],
          {ndef.name: ndef for ndef in fdef.node_def}, functions)

    if node_def.op in ("Identity", "While"):
      # Captured resources occur at the same index in the lists of inputs and
      # outputs of a while or identity op. So we look up the input of
      # `tensor.op` at the same index as the index of `tensor` in
      # `tensor.op.outputs`.
      tensor_name = node_def.input[output_idx]
    elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
      # Functions output any captured resource tensors used by their
      # gradients.  `tensor_name` is one of these outputs from a nested
      # function call, so recursively find the corresponding input in the
      # nested FunctionDef.
      tensor_name = node_def.input[_extract_input_index("f")]
    elif node_def.op in ("If", "StatelessIf"):
      input_index = _extract_input_index("then_branch")
      if input_index != _extract_input_index("else_branch"):
        raise AssertionError(
            ("Expected cond branches ({} op) to each have the same "
             "input->output mapping of resources.").format(node_def.op))
      tensor_name = node_def.input[
          # Ignore the `cond` input; the function inputs come after.
          input_index + 1]
    else:
      # We assume there are no other op types that will "forward" resource
      # handles like this, so all other handles must have been created by the
      # op. (Note that cond_v2 wraps resource handle outputs in optionals,
      # which we'll end up accumulating.)
      raise ValueError("Taking gradient of a while loop which creates "
                       "a resource in its body is not supported: %s (%s)"
                       % (op_name, node_def.op))

  return input_names.index(tensor_name)


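# Illustrative example (not part of the original module): the loop above
# normalizes the two tensor naming conventions that `resource_input_index`
# has to handle. The names below are hypothetical:
#
#   "my_var"               # 1 part:  bare op name, ":0" suffix was stripped
#   "while/Identity:0"     # 2 parts: graph-style <op_name>:<output_index>
#   "body_fn:output:0"     # 3 parts: FunctionDef-style
#                          #          <op_name>:<output_arg_name>:<index>
#
# In each case the name is traced back through Identity/While/call/If ops
# until it matches one of `input_names`, whose index is then returned.
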
@tf_contextlib.contextmanager
def clear_control_inputs():
  """Clears the control inputs but preserves the ControlFlowContext.

  This is needed to preserve the XLAControlFlowContext when clearing
  control inputs for the gradient accumulators in while_v2.
  `ops.control_dependencies` does not allow that.

  Yields:
    A context manager in which the ops created will not have any control inputs
    by default but the control flow context is the same.
  """
  # pylint: disable=protected-access
  control_flow_context = ops.get_default_graph()._get_control_flow_context()
  with ops.control_dependencies(None):
    ops.get_default_graph()._set_control_flow_context(control_flow_context)
    yield
  # pylint: enable=protected-access


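# Illustrative example (not part of the original module): ops created under
# `clear_control_inputs` do not pick up pending control dependencies, but they
# keep the surrounding control flow context (e.g. an XLAControlFlowContext):
#
#   with ops.control_dependencies([some_op]):  # `some_op` is hypothetical
#     with clear_control_inputs():
#       acc = build_accumulator()              # hypothetical; created without
#                                              # `some_op` as a control input
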
def _is_tpu_strategy(strategy):
  return (strategy is not None and
          strategy.__class__.__name__.startswith("TPUStrategy"))


def _is_building_keras_layer():
  # TODO(srbs): Remove this function when we no longer support sessions with
  # Keras.
  keras_call_context_function = keras_deps.get_call_context_function()
  if keras_call_context_function:
    return keras_call_context_function().layer is not None
  else:
    return False


def output_all_intermediates():
  """Whether to output all intermediates of a functional control flow op.

  The default behavior is to output intermediates only when building a Keras
  Layer in graph mode, and even then only when certain other conditions are
  met:
  1. We do not output intermediates if the functional control flow op is being
     built inside a FuncGraph which is not an If/While graph. This guards
     against outputting intermediates in eager mode, since Keras adds tensors
     to a FuncGraph named "keras_graph" in that case. Also, because we do not
     output intermediates of tf.function (this feature exists only for
     backwards compatibility), outputting intermediates of functional control
     flow ops built inside tf.function is of no value.
  2. We do not output intermediates when the compilation is using XLA or when
     a TPU strategy is in use.
  3. We do not output intermediates when a single-threaded executor is used,
     since that executor does not perform inlining and pruning.

  Returns:
    A bool telling whether to output all intermediates.
  """
  if _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE is not None:
    return _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
  if in_defun():
    return False
  if (control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()) or
      _is_tpu_strategy(distribution_strategy_context.get_strategy())):
    return False
  if (context.context().function_call_options.executor_type ==
      "SINGLE_THREADED_EXECUTOR"):
    return False
  return _is_building_keras_layer()


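# Illustrative example (not part of the original module): the module-level
# override takes precedence over every other check in
# `output_all_intermediates`:
#
#   _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
#   output_all_intermediates()   # -> True, regardless of the other checks
#
#   _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None
#   output_all_intermediates()   # -> falls back to the checks listed above
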
def get_func_graph(op, input_shapes, func_name):
  """Generates and returns a FuncGraph for the given op and input_shapes."""
  fdef = None
  graph = op.graph
  # Search for the function in this graph and its outer graphs.
  while graph is not None:
    func = graph._get_function(func_name)  # pylint: disable=protected-access
    if func is not None:
      fdef = func.definition
      break
    if hasattr(graph, "outer_graph"):
      graph = graph.outer_graph
    else:
      break

  if fdef is None:
    raise KeyError("%s cannot be found in the graph" % func_name)

  # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
  # in the case of nested if ops or when the gradient is being computed
  # from inside a Defun. We build the `func_graph` with `op.graph` as its
  # `outer_graph`. This resembles how the `FuncGraph` was built in the
  # forward pass. We need this so that we can resolve references to tensors
  # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
  with op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(
        fdef, input_shapes)
  return func_graph


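# Illustrative sketch (not part of the original module), assuming `while_op`
# is a `While` op whose gradient is being built and `input_shapes` holds the
# shapes for the function's inputs (e.g. derived from the op's attrs):
#
#   body_name = while_op.get_attr("body").name
#   body_graph = get_func_graph(while_op, input_shapes, body_name)
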
def get_op_and_outputs(op_or_outputs):
  """Returns an (op, outputs) pair for an Operation or a list of outputs."""
  if isinstance(op_or_outputs, ops.Operation):
    return op_or_outputs, []
  elif not op_or_outputs:  # Empty list.
    return None, []
  else:
    return op_or_outputs[0].op, op_or_outputs


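# Illustrative example (not part of the original module), where `op` is an
# Operation and `t1`, `t2` are tensors produced by the same op:
#
#   get_op_and_outputs(op)        # -> (op, [])
#   get_op_and_outputs([t1, t2])  # -> (t1.op, [t1, t2])
#   get_op_and_outputs([])        # -> (None, [])
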
def graph_wrapped_for_higher_order_tape_gradients(graph):
  """Check if `graph` is wrapped by `run_as_function_for_tape_gradients`."""
  while graph is not None:
    if "cflow_gradient_wrapper" in getattr(graph, "name", ""):
      return True
    graph = getattr(graph, "outer_graph", None)
  return False


def run_as_function_for_tape_gradients(make_op, inputs):
  """Fix higher-order tape gradients by wrapping `make_op` in a function.

  Args:
    make_op: A function that takes a list of inputs and returns a list of output
      tensors. This function should set any handle data relevant to its outputs
      before returning.
    inputs: A list of tensors to check for tape gradients and pass to
      `make_op`. These should include all tensors used in `make_op`.

  Returns:
    Tensors corresponding to `make_op`'s output.
  """
  # GradientTapes created inside a function currently don't work well with
  # un-wrapped control flow ops in that same function. Wrapping in an extra
  # layer of intermediate function means we run extra logic in the function
  # gradient code to record the correct intermediates on the tape.
  #
  # The function attribute inputs to control flow ops are not hashable, so we
  # pass everything as a capture to bypass defun's caching.
  if (gradients_util.PossibleTapeGradientTypes(inputs)
      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
      # We only need one function between the tape and the op; if we've already
      # wrapped once, we stop wrapping to avoid infinite recursion.
      and not (ops.get_default_graph().building_function
               and "cflow_gradient_wrapper" in ops.get_default_graph().name)):
    results = function.defun_with_attributes(
        make_op,
        autograph=False,
        attributes=dict(func_name="cflow_gradient_wrapper"))(inputs)
    return results
  else:
    return make_op(inputs)

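# Illustrative example (not part of the original module): cond_v2/while_v2
# style callers pass an op-building callback plus every tensor it uses, and
# get back either the callback's direct outputs or the outputs of a wrapping
# "cflow_gradient_wrapper" function. `_build_cond_op` is hypothetical:
#
#   outputs = run_as_function_for_tape_gradients(
#       lambda op_inputs: _build_cond_op(op_inputs), cond_inputs)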