1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""while_v2 and gradient.
16
17This is a version of while_loop that emits a single While op, as well as the
18gradient function for While ops produced by while_loop. This will eventually
19replace the current tf.while_loop implementation once it reaches feature and
20performance parity.
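
Example (an illustrative sketch; users normally reach this code through
tf.while_loop, which dispatches here when control flow v2 is enabled):

  i = constant_op.constant(0)
  # Emits a single StatelessWhile/While op for the whole loop.
  [ten] = while_loop(lambda i: i < 10, lambda i: i + 1, [i])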
21"""
22import collections
23
24from tensorflow.core.framework import attr_value_pb2
25from tensorflow.python.client import pywrap_tf_session as c_api
26from tensorflow.python.eager import backprop_util
27from tensorflow.python.framework import auto_control_deps_utils as acd
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import func_graph as func_graph_module
31from tensorflow.python.framework import indexed_slices
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import tensor_spec
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import control_flow_ops
38from tensorflow.python.ops import control_flow_util as util_v1
39from tensorflow.python.ops import control_flow_util_v2 as util
40from tensorflow.python.ops import default_gradient
41from tensorflow.python.ops import gen_functional_ops
42from tensorflow.python.ops import gen_resource_variable_ops
43from tensorflow.python.ops import gradients_util
44from tensorflow.python.ops import handle_data_util
45from tensorflow.python.ops import list_ops
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import tensor_array_ops
48from tensorflow.python.ops import while_v2_indexed_slices_rewriter
49from tensorflow.python.util import compat
50from tensorflow.python.util import nest
51from tensorflow.python.util import object_identity
52from tensorflow.python.util import variable_utils
53
54# pylint: disable=protected-access
55
# Controls parallelism in the presence of side-effecting ops like variable
# operations, print, py_function, etc. Can be set to True, False (the current
# default), or "stateless_cond".
# Note that loops without side-effecting operations always execute with maximum
# parallelism, ignoring this setting. When False, loops with side-effecting ops
# execute sequentially, one iteration at a time.
# When True, loops with side-effecting ops may execute parts of different
# iterations in parallel; caution: if the loop condition contains
# side-effecting ops, this mode produces unspecified results.
# Setting it to "stateless_cond" behaves like True when the loop condition is
# free of side-effecting ops, and like False otherwise.
67# TODO(b/152548567): Change this to "stateless_cond".
68glob_stateful_parallelism = False
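# For example (an illustrative sketch; this module-level flag is internal and
# not part of the public API), a debugging session might force sequential
# execution of stateful loop bodies with:
#
#   from tensorflow.python.ops import while_v2
#   while_v2.glob_stateful_parallelism = False  # or True, or "stateless_cond"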
69
70
71def while_loop(cond,
72               body,
73               loop_vars,
74               shape_invariants=None,
75               parallel_iterations=10,
76               maximum_iterations=None,
77               name=None,
78               return_same_structure=True,
79               back_prop=True):
80  """Like tf.while_loop, except emits a single While op."""
81  loop_vars = variable_utils.convert_variables_to_tensors(loop_vars)
82  # Keep the original loop_vars around to know which args were TensorArrays.
83  orig_loop_vars = loop_vars
84  flat_orig_loop_vars = nest.flatten(orig_loop_vars, expand_composites=True)
85  # Cache its length since we use it at multiple places below.
86  len_orig_loop_vars = len(orig_loop_vars)
87
88  # Convert TensorArrays to their flow variables. These get converted back to
89  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
90  # `wrapped_body` below.
91  loop_vars = _tensor_array_to_flow(loop_vars)
92  loop_vars = nest.map_structure(
93      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
94      expand_composites=True)
95
  # `loop_vars_signature` is a structure of TypeSpecs that has the same
  # structure as `orig_loop_vars`. If `shape_invariants` is not None, its shape
  # information comes from `shape_invariants` instead of `orig_loop_vars`.
  # It is used to pack flattened vars into structured vars.
100  if shape_invariants is not None:
101    loop_vars_signature = nest.map_structure(
102        control_flow_ops._shape_invariant_to_type_spec,
103        loop_vars, shape_invariants)
104  else:
105    loop_vars_signature = nest.map_structure(
106        control_flow_ops._shape_invariant_to_type_spec, loop_vars)
107
108  flat_shape_invariants = nest.map_structure(
109      lambda spec: spec.shape,
110      nest.flatten(loop_vars_signature, expand_composites=True))
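  # For example (illustrative), with loop_vars = [<float32 Tensor of shape
  # [3]>] and shape_invariants = [TensorShape([None])], `flat_shape_invariants`
  # is [TensorShape([None])]; these shapes later override the corresponding
  # While op output shapes so those loop vars may change shape across
  # iterations.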
111
112  if not name:
113    name = "while"
114
115  with ops.name_scope(name) as scope:
116    with ops.name_scope(None):
117      cond_name = util.unique_fn_name(scope, "cond")
118      body_name = util.unique_fn_name(scope, "body")
119    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
120        maximum_iterations)
121    loop_counter = constant_op.constant(
122        0,
123        dtype=maximum_iterations_loop_var.dtype
124        if maximum_iterations is not None else None,
125        name="loop_counter")
126    # Add loop counter needed for computing gradients.
127    loop_vars = [loop_counter, maximum_iterations_loop_var] + list(loop_vars)
128
129    func_graph_signature = (
130        [tensor_spec.TensorSpec.from_tensor(loop_counter),
131         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
132        list(loop_vars_signature))
133
134    # Automatic control dependencies are added in defuns, but not in v1
135    # graphs. Propagate that behavior here.
136    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
137
138    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
139      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars`, but currently there
      # is no nest.zip so we call `_pack_sequence_as`, which flattens `args`,
      # converts the flows in `args` to TensorArrays and packs them into the
      # structure of `loop_vars_signature`.
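      # For example (illustrative), if `orig_loop_vars` were
      # (counter, some_tensor_array), then `args` here is
      # (counter, tensor_array_flow) and `cond` below is called with
      # (counter, rebuilt_tensor_array).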
145      pred = cond(
146          *_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args))
147      if (tensor_util.is_tf_type(pred) and
148          (pred.shape.dims is None or pred.shape.dims)):
149        pred = array_ops.squeeze_v2(pred)
150
151      if maximum_iterations is None:
152        return pred
153      else:
154        return math_ops.logical_and(
155            loop_counter < maximum_iterations_arg, pred)
156
157    # NOTE(skyewm): we set collections to the outer graph's collections for
158    # compatibility with TPUEstimator.
159    cond_graph = func_graph_module.func_graph_from_py_func(
160        cond_name,
161        wrapped_cond,
162        [],  # We provide signature instead of args.
163        {},
164        signature=func_graph_signature,
165        func_graph=util.WhileCondFuncGraph(
166            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
167        add_control_dependencies=add_control_dependencies)
168
169    if glob_stateful_parallelism == "stateless_cond":
170      stateful_parallelism = (not any(
171          op._is_stateful for op in cond_graph.get_operations()))
172    else:
173      stateful_parallelism = glob_stateful_parallelism
174
175    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
176      """Loop body augmented with counter update.
177
178      Args:
179        loop_counter: Loop counter which needs to be incremented in the body.
180        maximum_iterations_arg: Maximum iterations of the loop.
        *args: The flattened loop variables, with TensorArrays represented by
          their flow tensors.
182
183      Returns:
184        A list of tensors the same length as args.
185      """
186      # The function was created with a signature rather than tensors, so
187      # internal placeholders were created without handle data.
188      _copy_handle_data(nest.flatten(loop_vars[2:], expand_composites=True),
189                        nest.flatten(args, expand_composites=True))
190      # Capture the tensors already captured in cond_graph so that they appear
191      # in the same order in body_graph.external_captures.
192      for t in cond_graph.external_captures:
193        ops.get_default_graph().capture(t)
194
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars`, but currently there
      # is no nest.zip so we call `_pack_sequence_as`, which flattens `args`,
      # converts the flows in `args` to TensorArrays and packs them into the
      # structure of `loop_vars_signature`.
200      outputs = body(
201          *_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args))
202      if not nest.is_nested(outputs):
203        outputs = [outputs]
204      try:
205        # The legacy while_loop considers list and tuple to be the same
206        # structure.
207        nest.assert_same_structure(outputs, orig_loop_vars, check_types=False,
208                                   expand_composites=True)
209      except ValueError:
210        # Traditionally we consider variables and tensors to be the same
211        # structure.
212        vars1 = variable_utils.convert_variables_to_tensors(outputs)
213        vars2 = variable_utils.convert_variables_to_tensors(orig_loop_vars)
214        nest.assert_same_structure(vars1, vars2, check_types=False,
215                                   expand_composites=True)
216      outputs = _tensor_array_to_flow(outputs)
217
218      # TODO(srbs): Update lowering code to create _Enter nodes with
219      # is_constant=True for inputs that are directly passed to outputs.
220      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)
221
222    body_graph = func_graph_module.func_graph_from_py_func(
223        body_name,
224        wrapped_body,
225        [],  # We provide signature instead of args.
226        {},
227        signature=func_graph_signature,
228        func_graph=util.WhileBodyFuncGraph(
229            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
230        add_control_dependencies=add_control_dependencies,
231        acd_record_initial_resource_uses=stateful_parallelism)
232    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e., the
    # value of each such tensor in every iteration is the same as it was at the
    # beginning of the loop execution.
236    deferred_external_captures = nest.flatten(
237        [c() for c in body_graph.deferred_external_captures],
238        expand_composites=True)
239    loop_vars = (
240        loop_vars + body_graph.external_captures + deferred_external_captures)
241    # TODO(srbs): Update lowering code to create _Enter nodes with
242    # is_constant=True for inputs that are directly passed to outputs.
243    body_graph.outputs.extend(body_graph.internal_captures)
244    body_graph.outputs.extend(body_graph.deferred_internal_captures)
245
246    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
247    # that it expects to receive those as arguments.
248    with cond_graph.as_default():
249      num_cond_captures = len(cond_graph.external_captures)
250      assert (cond_graph.external_captures ==
251              body_graph.external_captures[:num_cond_captures])
252      _duplicate_body_captures_in_cond(
253          cond_graph, body_graph.external_captures[num_cond_captures:] +
254          deferred_external_captures)
255
256    # Make sure that the shapes of the loop outputs are compatible with the
257    # shape invariants, or the shapes of the loop vars if the invariants are not
258    # specified.
259    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
260                                             expand_composites=True))
261    # First var is loop counter and second var is maximum_iterations.
262    first_loop_var_index = 2
263    _check_shapes_compat(
264        body_graph.outputs[first_loop_var_index:first_loop_var_index +
265                           num_flattened_outputs],
266        flat_shape_invariants,
267        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
268                               len_orig_loop_vars], expand_composites=True))
269
270    num_original_outputs = len(body_graph.outputs)
271    if back_prop and util.output_all_intermediates():
272      # Export all tensors in the loop body that may be needed for gradient
273      # computation. We do this by accumulating the intermediate values in
274      # TensorLists.
275      intermediate_tensors = _get_intermediates(body_graph)
276
277      for intermediate_tensor in intermediate_tensors:
278        tensor_list = list_ops.empty_tensor_list(
279            element_dtype=intermediate_tensor.dtype,
280            element_shape=intermediate_tensor.shape,
281            max_num_elements=maximum_iterations)
282        loop_vars.append(tensor_list)
283        with cond_graph.as_default():
284          # Add a placeholder to cond_graph's inputs corresponding to the
285          # tensor_list.
286          cond_graph.capture(tensor_list)
287        with body_graph.as_default():
288          # Push the intermediate tensor to the tensor list. This captures the
289          # `tensor_list` as well.
290          appended_tensor_list = list_ops.tensor_list_push_back(
291              tensor_list, intermediate_tensor)
292          # Add this modified tensor list to the list of outputs.
293          body_graph.outputs.append(appended_tensor_list)
294
295    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
296    _check_num_inputs_outputs(cond_graph, body_graph,
297                              len(flattened_loop_vars))
298    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)
299
300    with ops.control_dependencies(
301        list(cond_graph.control_captures) + list(body_graph.control_captures)):
302      output_shapes = [t.shape for t in body_graph.outputs]
303      orig_loop_vars_range = slice(first_loop_var_index,
304                                   first_loop_var_index + num_flattened_outputs)
305      output_shapes[orig_loop_vars_range] = flat_shape_invariants
306
307      outputs = _build_while_op(
308          flattened_loop_vars,
309          cond_graph,
310          body_graph,
311          output_shapes=output_shapes,
312          parallel_iterations=parallel_iterations,
313          name=scope,
314          num_original_outputs=num_original_outputs,
315          stateful_parallelism=stateful_parallelism)
316    if not ops.get_default_graph().building_function:
317      # In V1 graph mode, return identities for each output of the While op,
318      # rather than the output of the While op directly. This makes pruning work
319      # if the output of while_loop() is fetched: the lowering pass converts the
      # While outputs into IdentityN outputs, which, if fetched, will cause all
321      # ops in the body to be run (since it takes all exit ops as input). After
322      # lowering, each output identity op will end up with only the appropriate
323      # exit op as input.
324      outputs = tuple(array_ops.identity(t) for t in outputs)
325
326  output_loop_vars = outputs[first_loop_var_index:first_loop_var_index +
327                             num_flattened_outputs]
328  if not back_prop:
329    output_loop_vars = [array_ops.stop_gradient(t) for t in output_loop_vars]
330  outputs = _pack_sequence_as(
331      loop_vars_signature, flat_orig_loop_vars, output_loop_vars)
332
333  if return_same_structure:
334    return outputs
335
336  flattened_outputs = nest.flatten(outputs, expand_composites=True)
337  if len(flattened_outputs) == 1:
338    return flattened_outputs[0]
339  else:
340    return outputs
341
342
343@ops.RegisterGradient("StatelessWhile")
344@ops.RegisterGradient("While")
345def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
346  """The gradient of a While op produced by while_loop."""
347  # Note that op is not always the same as while_op because the gradient tape,
348  # for eager mode compatibility, forgets information about the proper op. Since
349  # the loop cannot run in eager mode, however, we can safely introspect into
350  # the graph here.
351  while_op = op.outputs[0].op
352  cond_graph = _get_graph(while_op, "cond", "_cond_graph")
353  body_graph = _get_graph(while_op, "body", "_body_graph")
354  orig_num_params = len(body_graph.outputs)
355
356  maximum_iterations = op.inputs[1]
357  parallel_iterations = op.get_attr("parallel_iterations")
358
359  try:
360    num_original_outputs = while_op.get_attr("_num_original_outputs")
361  except:  # pylint: disable=bare-except
362    num_original_outputs = len(while_op.outputs)
363
364  try:
365    stateful_parallelism = while_op.get_attr("_stateful_parallelism")
366  except:  # pylint: disable=bare-except
367    stateful_parallelism = False
368
369  num_intermediates = len(while_op.outputs) - num_original_outputs
370  grads = [
371      _preprocess_grad(grad, body_out, while_in, while_out)  # pylint: disable=g-complex-comprehension
372      for grad, body_out, while_in, while_out in zip(
373          grads[:num_original_outputs],
374          body_graph.outputs[:num_original_outputs],
375          while_op.inputs[:num_original_outputs],
376          while_op.outputs[:num_original_outputs])
377  ] + [None] * num_intermediates
378
379  # Skip gradients with respect to the captures whenever possible.
380  if "skip_input_indices" in op.__dict__ and op.skip_input_indices is not None:
381    captures_start_index = (
382        len(body_graph.inputs) - len(body_graph.internal_captures))
383    for i in op.skip_input_indices:
384      if i >= captures_start_index:
385        grads[i] = None
386
  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad Nones back into the list of
  # output gradients.
390  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
391      body_graph.outputs, body_graph.inputs, grads) if grad is not None])
392
393  body_grad_graph, args = _create_grad_func(
394      ys, xs, non_none_grads, cond_graph, body_graph,
395      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations,
396      stateful_parallelism)
397
398  if body_grad_graph.while_op_needs_rewrite:
399    # Modify 'op' to output the intermediate accumulators needed by the grad
400    # function.
401    # NOTE(skyewm): if there are any active sessions, this modification to `op`
402    # may make them unrunnable!
403
404    cond_graph.name += "_rewritten"
405    body_graph.name += "_rewritten"
406
407    # `body_grad_graph.extra_inputs` here is equivalent to skimming off the new
408    # `body_graph.external_captures` added during `_create_grad_func`.
409    new_inputs = body_grad_graph.extra_inputs
410    new_outputs = body_graph.outputs[orig_num_params:]
411
412    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
413    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
414    if len(body_graph.output_types) != len(while_op.inputs) + len(new_inputs):
415      # Continuing leads to an invalid graph with disconnected inputs.
      raise AssertionError(
          "Inputs and outputs constructed for the forward op of a While "
          "gradient don't match: 'output_types' has length "
          f"{len(body_graph.output_types)}, 'inputs' has length "
          f"{len(while_op.inputs)}, and 'new_inputs' has length "
          f"{len(new_inputs)}. This doesn't make sense, please file a bug.")
422    while_op._set_type_list_attr("T", body_graph.output_types)
423    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
424    while_op._add_while_inputs(new_inputs)
425    while_op._add_outputs([t.dtype for t in new_outputs],
426                          [t.shape for t in new_outputs])
427    _copy_handle_data(new_outputs, while_op.outputs[orig_num_params:])
428
429  # Do not ignore grads wrt extra outputs when computing higher order
430  # derivatives.
431  while_op._set_attr("_num_original_outputs",
432                     attr_value_pb2.AttrValue(i=len(while_op.outputs)))
433
434  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
435                                           while_op)
436  loop_vars = args + captured_inputs
437
438  # This modifies body_grad_graph.
439  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
440      grads, body_grad_graph, loop_vars, while_op.inputs)
441
442  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
443                *unused_args):
444    return counter < forward_loop_iters
445
446  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
447  cond_grad_graph = func_graph_module.func_graph_from_py_func(
448      grad_cond_name, grad_cond, loop_vars, {},
449      func_graph=util.WhileCondFuncGraph(grad_cond_name))
450
451  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))
452
453  outputs = _build_while_op(
454      loop_vars,
455      cond_grad_graph,
456      body_grad_graph,
457      output_shapes=[t.shape for t in body_grad_graph.outputs],
458      parallel_iterations=parallel_iterations,
459      name="%s_grad" % while_op.name,
460      num_original_outputs=len(body_grad_graph.outputs),
461      stateful_parallelism=stateful_parallelism)
462
463  # See comment in while_loop.
464  outputs = [array_ops.identity(t) for t in outputs]
465  return _get_structured_grad_output(outputs, grads, body_grad_graph)
466
467
468def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
469                    parallel_iterations, name, num_original_outputs,
470                    stateful_parallelism):
471  """Builds the functional StatelessWhile/While op."""
472  cond_stateful_ops = [
473      op for op in cond_graph.get_operations() if op._is_stateful
474  ]
475  body_stateful_ops = [
476      op for op in body_graph.get_operations() if op._is_stateful
477  ]
478  if (cond_stateful_ops or body_stateful_ops):
479    op_fn = gen_functional_ops._while
480  else:
481    op_fn = gen_functional_ops.stateless_while
482
483  def _make_op(inputs):
484    while_op, tensors = util.get_op_and_outputs(op_fn(
485        inputs,
486        util.create_new_tf_function(cond_graph),
487        util.create_new_tf_function(body_graph),
488        output_shapes=output_shapes,
489        parallel_iterations=parallel_iterations,
490        name=name))
491    _copy_handle_data(body_graph.outputs, tensors)
492    util.maybe_set_lowering_attr(while_op)
493    util.maybe_propagate_compile_time_consts_in_xla(while_op)
494    _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
495    # This is needed so we do not compute derivative wrt these extra outputs.
496    while_op._set_attr("_num_original_outputs",
497                       attr_value_pb2.AttrValue(i=num_original_outputs))
498    while_op._set_attr("_stateful_parallelism",
499                       attr_value_pb2.AttrValue(b=stateful_parallelism))
    # The while op may be created inside a tf.function, in which case ops
    # need to capture "through" it when taking gradients; outer_graph is used
    # as a sanity check that capturing only happens from parent to child.
503    cond_graph.outer_graph = ops.get_default_graph()
504    body_graph.outer_graph = ops.get_default_graph()
505    while_op._cond_graph = cond_graph
506    while_op._body_graph = body_graph
507    return tensors
508  return util.run_as_function_for_tape_gradients(_make_op, loop_vars)
509
510
511def _get_intermediates(func_graph):
512  """Returns all tensors in `func_graph` that should be accumulated."""
513  # We currently accumulate output tensors of most ops in the function and rely
514  # on the pruning pass to get rid of the unused accumulators at runtime.
515  # However, this can bloat the GraphDef and make debugging harder so we perform
516  # some optimizations.
517  #
  # Optimizations we currently perform:
519  # 1. We do not accumulate tensors which already have an accumulator
520  #    in the loop body.
521  # 2. We do not accumulate outputs of Identity nodes. When building the
522  #    FuncGraph, we add an Identity node for each output (see
523  #    `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
524  #    of all these nodes bloats the GraphDef quite a bit so we remove those.
525  #    Since the gradient of an Identity node does not rely on its forward op's
526  #    input this is safe to do.
527  #
528  # Other possible optimizations:
529  # 1. Only accumulate tensors that will be required by the backward pass.
530  #    This will require running the gradient pass and hence would increase the
531  #    graph building time for the forward pass.
532  # 2. Do not accumulate Const nodes created inside the loop body.
533  # 3. Do not accumulate loop vars that are returned as-is just like captured
534  #    tensors.
535  intermediates = []
536  reverse_captures = dict((v.ref(), k) for k, v in func_graph.captures)
537
538  for op in func_graph.get_operations():
539    if op.type == "Identity":
540      continue
541    # Accumulating mutexes can cause deadlock.
542    if op.type == "MutexLock":
543      continue
544    for o in op.outputs:
      if (o is not func_graph.inputs[0] and  # Skip the loop counter.
          o.dtype != dtypes.resource and  # Do not accumulate resource tensors.
          _get_accumulator(o) is None and  # Skip if already accumulated.
          o.ref() not in reverse_captures  # Skip captures (loop invariants).
         ):
550        intermediates.append(o)
551  return intermediates
552
553
554def _preprocess_grad(grad, body_graph_output, while_op_input, while_op_output):
555  """Returns the initial gradient to be used for a given output tensor.
556
557  Args:
558    grad: the original gradient Tensor passed to the gradient function.
559    body_graph_output: the corresponding Tensor in the body graph.
560    while_op_input: the corresponding Tensor input of the While op.
561    while_op_output: the corresponding Tensor output of the While op.
562
563  Returns:
564    A Tensor or None.
565  """
566  # Set the incoming gradient of non-trainable inputs to None. It is possible
567  # that we receive non-None gradients for non-trainable types in nested while
568  # loops because we accumulate outputs of the inner while as variant tensors
569  # which are trainable and hence receive zeros_like tensors in the gradient
570  # pass. The non-trainable tensors then receive the popped zeros tensor from
571  # this zeros variant. The gradient for the loop vars corresponding to these
572  # tensors is None or zeros (this happens only if the loop var is accumulated
573  # as well) in _grad_fn so we reset these.
574  # TODO(b/118712257): Remove once we can handle None output grads in _grad_fn.
575  if not _is_trainable(body_graph_output):
576    return None
577
578  # GradientTape initializes resource and variant grads as None instead of
579  # zeros. Set to zeros so _GradientsHelper computes the gradients instead of
580  # returning None.
581  # TODO(b/143286622): The supports_default_grad check is needed
582  # because While op emits non-differentiable resource tensors
583  # as outputs. Remove this check when that is not the case.
584  # Note: We use `while_op_input` instead of `while_op_output` for the call
585  # to `supports_default_grad` because `while_op_output` may be missing
586  # handle_data if the While is in a restored saved model.
587  if (while_op_output.dtype in (dtypes.resource, dtypes.variant) and
588      default_gradient.supports_default_grad(while_op_input) and grad is None):
589    return _zeros_like(while_op_input, while_op_output)
590
  # Convert IndexedSlices to dense tensors since it is unlikely that downstream
  # gradient functions will properly handle indexed slices. This is similar to
  # what we do in tf.function gradients.
594  if isinstance(grad, indexed_slices.IndexedSlices):
595    return ops.convert_to_tensor(grad)
596
597  return grad
598
599
600# TODO(skyewm): make this return constants if op_output's shape is fully
601# defined (this can be done by checking the "shape" attr of resource vars).
602def _zeros_like(op_input, op_output):
603  """Like array_ops.zeros_like() but also accepts resource var handles."""
604  if op_output.dtype == dtypes.resource:
605    # Note: We use `op_input` instead of `op_output` to get the zeros dtype
606    # because `op_output` may be missing handle_data if the While is in a
607    # restored saved model.
608    return array_ops.zeros(
609        gen_resource_variable_ops.variable_shape(op_output),
610        dtype=default_gradient.get_zeros_dtype(op_input))
611  return array_ops.zeros_like(op_output)
612
613
614def _is_trainable(tensor):
615  """Returns whether the given tensor is trainable."""
616  if not backprop_util.IsTrainable(tensor):
617    return False
618
619  # Special case: untrainable accumulator output. The gradients algorithm
620  # doesn't know about tensor lists of untrainable elements. In theory the
621  # tensor list gradient functions should return None as appropriate, but
622  # because we can't return None from the gradient function we filter out
623  # untrainable accumulator output here to avoid computing the gradient at all.
624  if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
625    assert tensor.dtype == dtypes.variant
626    element_type = tensor.op.get_attr("element_dtype")
627    return backprop_util.IsTrainable(element_type)
628
629  return True
630
631
632def _get_graph(while_op, func_attr_name, attr_graph_name):
633  """Returns `FuncGraph` for the given function attribute.
634
635  Args:
636    while_op: The While Operation.
    func_attr_name: Name of the function attribute on the While op, i.e. "cond"
      or "body".
    attr_graph_name: Name of the attribute on the While op under which the
      corresponding FuncGraph is cached, e.g. "_cond_graph" or "_body_graph".
639
640  Returns:
641    `FuncGraph`
642  """
643  func_graph = getattr(while_op, attr_graph_name, None)
644  if func_graph is None:
645    # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
646    input_shapes = [
647        tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
648    ]
649    func_name = while_op.get_attr(func_attr_name).name
650    func_graph = util.get_func_graph(while_op, input_shapes, func_name)
651  func_graph._while = while_op
652  return func_graph
653
654
655def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
656                      maximum_iterations, stateful_parallelism):
657  """Builds and returns the gradient FuncGraph of `func_graph` and its args.
658
659  The returned grad_func_graph must be called with the returned
660  args + grad_func_graph.captures.
661
662  Args:
663    ys: A `Tensor` or list of tensors to be differentiated.
664    xs: A `Tensor` or list of tensors to be used for differentiation.
665    grads: The incoming grads for `ys`.
666    cond_graph: FuncGraph for the forward cond function.
667    body_graph: FuncGraph for the forward body function.
668    name: Name of the returned gradient function.
669    while_op: The forward While op.
670    maximum_iterations: Tensor. The maximum number of iterations.
671    stateful_parallelism: Bool, see tf.while_loop.
672
673  Returns:
674    2-tuple of (grad_func_graph, args).
675  """
676  assert len(ys) == len(grads)
677
678  total_iters = while_op.outputs[0]
679  counter = constant_op.constant(
680      0, dtype=total_iters.dtype, name="grad_counter")
681
682  # Build frozen sets so that we do not have linear time lookups in
683  # `_is_loop_invariant`. Note: `body_graph.inputs` and `body_graph.outputs`
684  # may get updated during gradient computation because we add accumulators to
685  # the forward op. However, those are not loop invariants so wouldn't affect
686  # the output of `_is_loop_invariant`. Also we would never attempt to capture
687  # those accumulators so `_is_loop_invariant` should never receive those new
688  # tensors as args.
689  body_graph_inputs = object_identity.ObjectIdentitySet(body_graph.inputs)
690  body_graph_outputs = object_identity.ObjectIdentitySet(body_graph.outputs)
691
692  args = [counter, maximum_iterations, total_iters] + list(grads)
693  # Note: The returned function does not have `args` in the list of
694  # `external_captures`.
695  grad_func_graph = func_graph_module.func_graph_from_py_func(
696      name,
697      lambda *args: _grad_fn(ys, xs, args, body_graph),
698      args, {},
699      func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
700                                         maximum_iterations, while_op,
701                                         body_graph_inputs, body_graph_outputs),
702      acd_record_initial_resource_uses=stateful_parallelism)
703
704  # Update the list of outputs with tensors corresponding to the captured
705  # tensors. We capture 3 types of tensors when building the grad fn:
706  # 1. Accumulators for forward graph intermediates which are not loop
707  #    invariants. The outputs corresponding to these are populated in
708  #    `internal_capture_to_output` by `_WhileBodyGradFuncGraph`.
709  # 2. Resources, which are output as is.
710  # 3. Forward graph loop invariants, which are output as is.
711  for external_capture, internal_capture in grad_func_graph.captures:
712    if (ops.tensor_id(internal_capture)
713        in grad_func_graph.internal_capture_to_output):
714      new_output = grad_func_graph.internal_capture_to_output[ops.tensor_id(
715          internal_capture)]
716    else:
717      raise ValueError(
718          f"Tensor {str(internal_capture)} which captures "
719          f"{str(external_capture)} is in list of "
720          f"internal_captures but not in internal_capture_to_output.")
721    grad_func_graph.outputs.append(new_output)
722    grad_func_graph.structured_outputs.append(new_output)
723
724  return grad_func_graph, args
725
726
727def _grad_fn(ys, xs, args, func_graph):
728  """Computes the gradient of `func_graph` in the current graph.
729
730  This function builds the gradient graph of the corresponding forward-pass
731  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.
732
733  Args:
734    ys: A `Tensor` or list of tensors to be differentiated.
735    xs: A `Tensor` or list of tensors to be used for differentiation.
736    args: The input arguments.
      args[0] - Loop counter.
      args[1] - maximum_iterations.
      args[2] - Total number of iterations.
      args[3:] - Incoming gradients for `ys`.
741    func_graph: function.FuncGraph. The corresponding forward-pass function.
742
743  Returns:
744    The output gradient Tensors.
745  """
746  grad_ys = args[3:]
747
748  # Build the gradient graph. Note that this builds the gradient computation of
749  # func_graph in the current graph, which requires capturing tensors from
750  # func_graph. The captured func_graph tensors are resolved to external tensors
751  # after the forward While op has been rewritten in _resolve_grad_captures.
752  # TODO(srbs): Mark GradientsHelper as public?
753  grad_outs = gradients_util._GradientsHelper(
754      ys, xs, grad_ys=grad_ys, src_graph=func_graph,
755      unconnected_gradients="zero")
756
757  # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there
758  # is a tf.StopGradient in the loop body.
759  assert all(g is not None for g in grad_outs)
760  counter = args[0]
761  maximum_iterations = args[1]
762  total_iters = args[2]
763  return [counter + 1, maximum_iterations, total_iters] + grad_outs
764
765
766def _resolve_grad_captures(body_graph, body_grad_graph, while_op):
767  """Returns the tensors to pass as captured inputs to `body_grad_graph`.
768
769  `body_grad_graph` may have external references to:
770  1. Its outer graph containing the input gradients. These are left as-is.
771  2. Accumulators captured from the forward-pass graph. These should have been
772     added as `while_op` outputs after the gradient graph was built. We replace
773     these with the corresponding output of `while_op`, i.e. a tensor in
774     `body_graph.outer_graph`. In the case of nested control flow or functions,
775     the gradient logic handling `body_grad_graph.outer_graph` will make sure
776     the tensor from `body_graph.outer_graph` is also correctly captured.
777
778  Args:
779    body_graph: FuncGraph. The forward-pass body function.
780    body_grad_graph: FuncGraph. The body gradients function.
781    while_op: The forward-pass While Operation calling `body_graph`.
782
783  Returns:
784    A list of input tensors to be passed as the captured inputs to
785    `body_grad_graph`.
786  """
787  new_capture_inputs = []
788  for t in body_grad_graph.external_captures:
789    # Resolve tensors captured from the forward graph to the outputs of the
790    # forward while_op.
791    if t.graph == body_graph:
792      # Captured accumulator or loop invariant.
793      for i, output in enumerate(t.graph.outputs):
794        if output is t:
795          t = while_op.outputs[i]
796          break
797
798      # Note: We rely on the capturing logic of the gradient While op graph to
799      # correctly capture the tensors in `body_graph.outer_graph`. Both cond_v2
800      # and while_v2 handle this while building their gradient functions.
801      assert t.graph == body_graph.outer_graph
802
803    new_capture_inputs.append(t)
804  return new_capture_inputs
805
806
807def _get_structured_grad_output(outputs, grads, body_grad_graph):
808  """Returns the values that should be returned from the while grad function.
809
810  Args:
811    outputs: the raw Tensor outputs of the grad While op.
812    grads: the input gradients to the gradient function.
813    body_grad_graph: _WhileBodyGradFuncGraph.
814
815  Returns:
816    A list of gradient values. May include Nones.
817  """
818  result = []
819  # outputs[0] is the loop counter.
820  # outputs[1] is maximum_iterations.
821  # outputs[2] is the total number of loop iterations.
822  outputs_idx = 3
823  structured_outputs_idx = 3
824  for g in grads:
825    # Set None as the output gradient for tensors with None input gradient.
826    if g is None:
827      result.append(None)
828      continue
829    output = body_grad_graph.structured_outputs[structured_outputs_idx]
830    structured_outputs_idx += 1
831    if isinstance(output, indexed_slices.IndexedSlices):
832      # TODO(skyewm): is there a more robust way to determine the order of
833      # flattened IndexedSlices components?
834      result.append(indexed_slices.IndexedSlices(
835          values=outputs[outputs_idx],
836          indices=outputs[outputs_idx + 1],
837          dense_shape=outputs[outputs_idx + 2]))
838      outputs_idx += 3
839    else:
840      assert isinstance(output, ops.Tensor)
841      result.append(outputs[outputs_idx])
842      outputs_idx += 1
843
844  return result
845
846
847def _get_accumulator(tensor):
848  r"""Returns TensorList if any containing accumulated values of tensor.
849
850  We try to find a pattern of the form:
851
852     input_tl   tensor
853        \        /
854    (TensorListPushBack)
855            |
856        output_tl
857
858  which satisfies the following conditions:
859
860  1. input_tl must be in tensor.graph.inputs.
861  2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
  3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_tl).

  output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
  returned if such a pattern is found, else None is returned.
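
  For example (an illustrative sketch of graphs built by this module), a body
  graph that accumulates `x` into the loop var at index `i` roughly contains

      acc_in = tensor.graph.inputs[i]     # variant TensorList placeholder
      acc_out = list_ops.tensor_list_push_back(acc_in, x)

  with acc_out (or Identity(acc_out)) at tensor.graph.outputs[i]; in that case
  `_get_accumulator(x)` returns that graph output.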
866
867  Args:
868    tensor: The Tensor to be accumulated.
869
870  Returns:
871    A variant tensor in the same graph as `tensor` or None if no accumulator is
872    found.
873  """
874  assert isinstance(tensor.graph, func_graph_module.FuncGraph)
875
876  def get_func_graph_output(t):
877    """Returns t or Identity(t) whichever exists in graph outputs else None."""
878    for output in tensor.graph.outputs:
879      if output is t:
880        return t
    # tf.defun adds an Identity for each output; check whether that is the case.
    identity_op = t.consumers()[0]
    if (identity_op.type == "Identity" and
        any(identity_op.outputs[0] is out for out in tensor.graph.outputs)):
885      return identity_op.outputs[0]
886    return None
887
888  for consumer in tensor.consumers():
889    # Find the consumer that is a TensorListPushBack node whose TensorList input
890    # is in the list of function inputs.
891    if consumer.type != "TensorListPushBack":
892      continue
893
894    accum_input_idx = -1
895    for accum_input_idx, inp in enumerate(tensor.graph.inputs):
896      if inp is consumer.inputs[0]:
897        break
898    else:
899      continue
900
901    output = get_func_graph_output(consumer.outputs[0])
902    if output is None:
903      # The TensorList output of `consumer` is not in the list of function
904      # outputs.
905      continue
906
907    for accum_output_idx, out in enumerate(tensor.graph.outputs):
908      if out is output:
909        if accum_input_idx == accum_output_idx:
910          return output
911        break
912
913  return None
914
915
916OptimizedReductionOpsCacheKey = collections.namedtuple(
917    "OptimizedReductionOpsCacheKey", [
918        "op_type",
919        "inputs",
920        "dtypes",
921        "input_types",
922        "name",
923        "attrs",
924        "op_def",
925        "compute_device",
926    ])
927
928
929class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
930  """FuncGraph for the gradient function of the body of a While op.
931
932  Contains the logic for capturing the tensors from the body of the forward
933  While op which is as follows:
934  1. If the tensor is of resource type (these are not accumulated):
935     a. Ensure that the tensor is a loop invariant, i.e., it exists in both loop
936        inputs and outputs at the same index.
937     b. Lookup the corresponding resource tensor in the forward outer graph and
938        try to capture that.
939  2. If the tensor is not of resource type:
940     a. Create an accumulator for that tensor and output it from the forward
941        pass. Note this also requires adding it as an input to the forward pass.
942     b. Capture the accumulator from the forward pass in this FuncGraph. This
943        will later be resolved to the correct output of the forward While op.
944     c. Pop a value from the captured placeholder and use it as the captured
945        value for the forward pass tensor.
946
947  This only allows capturing tensors in the forward graph. A ValueError is
948  raised if an attempt is made to capture a tensor not in the forward graph.
949  To manually capture a tensor that is not in the forward graph, call `capture`
950  with `allowlisted=True`.
951
952  Note: The `captures` dict does not contain the forward tensor since it is not
953  directly captured. It contains the accumulator corresponding to this forward
954  tensor.
955
956  Attributes:
957    while_op_needs_rewrite: True if any non-resource intermediates were
958      captured, meaning the forward While op needs to be rewritten to output the
959      corresponding accumulators.
    extra_inputs: list of EmptyTensorList tensors to be used as initial input to
      the new accumulators in the forward graph. It may also contain external
      captures of the custom gradient function.
963    internal_capture_to_output: dict from a tensor_id(captured placeholder) to
964      the corresponding tensor that needs to be added to the list of outputs.
965      For instance, when capturing an accumulator TensorList this contains the
966      TensorList obtained after popping a tensor from the list. Other entries
967      in this dict are expected, though not enforced, to be identities.
968      This dict is needed because these output tensors need to be added to
969      FuncGraph.outputs "after" the tensors returned from the gradient function.
970  """
971
972  def __init__(self, name, forward_cond_graph, forward_body_graph,
973               maximum_iterations, forward_while_op, body_graph_inputs,
974               body_graph_outputs):
975    super(_WhileBodyGradFuncGraph, self).__init__(name)
976    self.extra_inputs = []
977    self.internal_capture_to_output = {}
978    # FuncGraph for the body of the forward While op.
979    self._forward_graph = forward_body_graph
980    # FuncGraph for the cond of the forward While op.
981    self._forward_cond_graph = forward_cond_graph
982    self._maximum_iterations = maximum_iterations
983    self._forward_while_op = forward_while_op
984    # Dict from forward intermediate tensor to its indirectly captured tensor
985    # in this graph. Indirect capturing happens in two ways:
986    # 1. For non-resource tensors we capture their accumulators from the forward
987    #    outer graph and pop values from that accumulator inside this graph
988    #    using TensorListPopBack.
989    # 2. For resource tensors we directly capture their corresponding tensor
990    #    in the forward outer graph.
991    self._indirect_captures = {}
992
993  @property
994  def while_op_needs_rewrite(self):
995    return self.extra_inputs
996
997  def _create_op_internal(
998      self,
999      op_type,
1000      inputs,
1001      dtypes=None,  # pylint: disable=redefined-outer-name
1002      input_types=None,
1003      name=None,
1004      attrs=None,
1005      op_def=None,
1006      compute_device=True):
1007    # For a reduction op, if op is in the gradient body graph and its input is
1008    # from the forward graph, moving op to the forward graph means we would
1009    # store the tensor after the reduction as opposed to the tensor before
1010    # reduction, and therefore could significantly reduce memory consumption.
1011    # For now, we do this only for a few ops.
1012    #
1013    # We don't do this if any input tensor has already been accumulated. This
1014    # can happen if we output all intermediates in the forward pass.
1015    #
1016    # If in XLA context, do not move constant ops to forward pass as pushing to
1017    # and popping from a TensorList removes the constant property of an op and
1018    # breaks XLA compilation, which requires certain inputs to be compile-time
1019    # constant for certain ops.
1020    #
1021    # This optimization is currently also disabled when under a persistent tape,
1022    # since it leads to an unbounded number of side outputs. With caching it may
1023    # be possible to re-enable it.
1024    optimized_reduction_ops = {
1025        "Shape", "Size", "Rank", "TensorListElementShape", "TensorListLength"
1026    }
1027    if (op_type in optimized_reduction_ops and
1028        not util.output_all_intermediates() and
1029        all(input.graph is self._forward_graph for input in inputs) and
1030        all(_get_accumulator(input) is None for input in inputs) and
1031        not util_v1.GraphOrParentsInXlaContext(self._forward_graph) and
1032        not util.graph_wrapped_for_higher_order_tape_gradients(
1033            self._forward_graph)):
1034      return self._move_op_to_forward_graph(
1035          op_type,
1036          inputs,
1037          dtypes=dtypes,
1038          input_types=input_types,
1039          name=name,
1040          attrs=attrs,
1041          op_def=op_def,
1042          compute_device=compute_device)
1043
1044    return super(_WhileBodyGradFuncGraph, self)._create_op_internal(
1045        op_type,
1046        inputs,
1047        dtypes=dtypes,
1048        input_types=input_types,
1049        name=name,
1050        attrs=attrs,
1051        op_def=op_def,
1052        compute_device=compute_device)
1053
1054  def _move_op_to_forward_graph(
1055      self,
1056      op_type,
1057      inputs,
1058      dtypes=None,  # pylint: disable=redefined-outer-name
1059      input_types=None,
1060      name=None,
1061      attrs=None,
1062      op_def=None,
1063      compute_device=True):
1064    # We have a cache of reduction ops that have already been moved to the
1065    # forward graph, and we will check it first to avoid moving an op twice.
1066    if not hasattr(self._forward_graph, "_optimized_reduction_ops_cache"):
1067      self._forward_graph._optimized_reduction_ops_cache = {}
1068    cache_key = self._get_optimized_reduction_ops_cache_key(
1069        op_type, inputs, dtypes, input_types, name, attrs, op_def,
1070        compute_device)
1071    cached_op = self._forward_graph._optimized_reduction_ops_cache.get(
1072        cache_key)
1073    if cached_op is not None:
1074      # This op has already been moved to the forward graph and we have it in
1075      # the cache.
1076      return cached_op
1077
1078    with self._forward_graph.as_default():
      # `name` was built using the name_scope stack of the gradient graph and
      # may not be unique in the forward graph. `Graph.create_op` does not
      # uniquify names that are name scopes, i.e. that end in `/`. To ensure
      # that the created op gets a unique name in the forward graph we get rid
      # of the trailing slash.
1084      name = ops.name_from_scope_name(name)
1085      result = self._forward_graph._create_op_internal(
1086          op_type,
1087          inputs,
1088          dtypes=dtypes,
1089          input_types=input_types,
1090          name=name,
1091          attrs=attrs,
1092          op_def=op_def,
1093          compute_device=compute_device)
1094
1095      # Store the op we just moved to the forward graph so that it does
1096      # not need to be added there again.
1097      self._forward_graph._optimized_reduction_ops_cache[cache_key] = result
1098      return result
1099
1100  def _get_optimized_reduction_ops_cache_key(
1101      self,
1102      op_type,
1103      inputs,
1104      dtypes=None,  # pylint: disable=redefined-outer-name
1105      input_types=None,
1106      name=None,
1107      attrs=None,
1108      op_def=None,
1109      compute_device=True):
1110    # We need all elements of CacheKey to be hashable.
1111    inputs = tuple(map(lambda t: t.ref(), inputs))
1112
1113    if dtypes is not None:
1114      dtypes = tuple(dtypes)
1115
1116    if input_types is not None:
1117      input_types = tuple(input_types)
1118
1119    if attrs is not None:
1120      hashable_attrs = []
1121      for attr_name, attr_value in sorted(attrs.items()):
1122        hashable_attrs.append((attr_name, attr_value.SerializeToString()))
1123      attrs = tuple(hashable_attrs)
1124
1125    if op_def is not None:
1126      op_def = op_def.SerializeToString()
1127
1128    return OptimizedReductionOpsCacheKey(op_type, inputs, dtypes, input_types,
1129                                         name, attrs, op_def, compute_device)
1130
1131  def _capture_helper(self, tensor, name):
1132    """Implements the capturing described in the class docstring."""
1133    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
1134    if captured_tensor is not None:
1135      return captured_tensor
1136
1137    if tensor.graph is not self._forward_graph:
1138      already_captured = self.captured(tensor)
1139      captured_tensor = super(_WhileBodyGradFuncGraph, self)._capture_helper(
1140          tensor, name)
1141      if not already_captured:
1142        # Adds the captured tensor to the list of outputs so that the input
1143        # and output signatures match.
1144        self.internal_capture_to_output[ops.tensor_id(
1145            captured_tensor)] = captured_tensor
1146        self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
1147      return captured_tensor
1148
1149    while tensor.op.type == "Identity":
1150      # We do not accumulate the output of identity nodes so we try to capture
1151      # the input of the Identity node instead.
1152      tensor = tensor.op.inputs[0]
1153
1154    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
1155    if captured_tensor is not None:
1156      return captured_tensor
1157
1158    # No need to accumulate loop invariants. Capture them directly.
1159    # The captured tensor gets resolved to the corresponding while output in
1160    # `_resolve_grad_captures`.
1161    if _is_loop_invariant(tensor, self._forward_graph.inputs,
1162                          self._forward_graph.outputs):
1163      captured_tensor = super(_WhileBodyGradFuncGraph,
1164                              self)._capture_helper(tensor, name)
1165      # Add to `internal_capture_to_output` so that this gets added to the list
1166      # of outputs.
1167      self.internal_capture_to_output[ops.tensor_id(
1168          captured_tensor)] = captured_tensor
1169      self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
1170      return captured_tensor
1171
1172    # Do not accumulate Const nodes. Instead copy them directly in the backward
1173    # graph.
1174    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
1175    # graph compile time consts in general.
1176    # TODO(srbs): Consider making this a loop input.
1177    if constant_op.is_constant(tensor):
1178      real_value = constant_op.constant(
1179          tensor_util.constant_value(tensor), dtype=tensor.dtype)
1180      self._indirect_captures[ops.tensor_id(tensor)] = real_value
1181      return real_value
1182
    # Resource tensors are not accumulated; they are handled specially below.
1184    if tensor.dtype == dtypes.resource:
1185      return self._resource_capture_helper(tensor)
1186
1187    # Create or find an existing accumulator output for `tensor` in the forward
1188    # graph, and fetch from this accumulator in the gradient graph to get the
1189    # raw intermediate value.
1190    accumulator = _get_accumulator(tensor)
1191    if accumulator is None:
1192      # Create the initial empty tensor list.
1193      #
1194      # Note: We clear the control dependencies to avoid a cycle in case a
      # control tensor has an input path to an output of the forward While.
1196      #
1197      # E.g.:
1198      # x = tf.while_loop(...)
1199      # y = f(x)
1200      # with tf.control_dependencies([y]):
1201      #   tf.gradients(y, x)
1202      #
1203      # Since the EmptyTensorList is fed back into the forward While, not
1204      # removing the control edge would cause a cycle.
1205      with self._forward_graph.outer_graph.as_default():
1206        with util.clear_control_inputs():
1207          tensor_list = list_ops.empty_tensor_list(
1208              element_dtype=tensor.dtype,
1209              element_shape=tensor.shape,
1210              max_num_elements=self._maximum_iterations,
1211              name=_build_accumulator_name(tensor))
1212      self.extra_inputs.append(tensor_list)
1213
1214      # Push the intermediate tensor to the tensor list. This captures
1215      # `tensor_list`.
1216      with self._forward_graph.as_default():
1217        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
      # Add the modified tensor list to the list of outputs. This output will
      # contain all the accumulated values.
1220      self._forward_graph.outputs.append(accumulator)
1221
1222      # Capture in the cond graph as well so the forward cond and body inputs
1223      # match.
1224      with self._forward_cond_graph.as_default():
1225        self._forward_cond_graph.capture(tensor_list)
1226
1227    # Capture the accumulator tensor list in the gradient graph directly from
1228    # the forward graph -- we'll later modify this to capture the final list
1229    # output by the forward While op instead.
1230    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
1231        accumulator, name)
1232
1233    # Pop the intermediate value from the tensor list in the gradient graph.
1234    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
1235        captured_accumulator, element_dtype=tensor.dtype)
1236
1237    self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
1238    self.internal_capture_to_output[ops.tensor_id(
1239        captured_accumulator)] = new_tensor_list
1240    return captured_tensor
1241
1242  def _resource_capture_helper(self, tensor):
1243    """Returns the captured resource tensor.
1244
1245    Resource-type tensors are not accumulated. If a resource tensor exists in
1246    the loop body, it must either be a loop input or an output of a nested
1247    While op inside the loop body that captured the external resource.
1248
1249    Args:
1250      tensor: the external resource Tensor to be captured.
1251
1252    Returns:
1253      Tensor in this graph.
1254    """
1255    assert tensor.dtype == dtypes.resource
1256
1257    forward_graph_input_names = [t.name for t in self._forward_graph.inputs]
1258    forward_graph_name_to_opdef = {
1259        op.name: op.node_def for op in self._forward_graph.get_operations()}
1260    index = util.resource_input_index(
1261        tensor.name, forward_graph_input_names,
1262        forward_graph_name_to_opdef,
1263        self._forward_graph._functions)
1264
1265    input_placeholder = self._forward_graph.inputs[index]
1266    tensor_in_outer_graph = self._forward_graph._while.inputs[index]
1267
1268    assert input_placeholder.dtype == dtypes.resource
1269    assert tensor_in_outer_graph.dtype == dtypes.resource
1270    # This must be a loop invariant. However, infrastructure
1271    # (e.g. tf.vectorized_map) may insert identity nodes, function calls, conds,
1272    # etc., which take and return the resource tensor unmodified; this means
1272    # that the Python objects may differ.
1274    if index != util.resource_input_index(
1275        self._forward_graph.outputs[index].name, forward_graph_input_names,
1276        forward_graph_name_to_opdef,
1277        self._forward_graph._functions):
1278      raise AssertionError(
1279          f"Resource tensors must be loop invariants {tensor_in_outer_graph}")
1280
1281    self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
1282        tensor_in_outer_graph)
1283    return self._indirect_captures[ops.tensor_id(tensor)]
1284
1285
1286def _check_shapes_compat(flat_output_tensors, flat_shape_invariants,
1287                         flat_input_tensors):
1288  for (t, shape, input_t) in zip(flat_output_tensors, flat_shape_invariants,
1289                                 flat_input_tensors):
1290    if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape):
1291      raise ValueError(
1292          f"Input tensor `{input_t.name}` enters the loop with shape {shape}, "
1293          f"but has shape {t.shape} after one iteration. To allow the shape to "
1294          "vary across iterations, use the `shape_invariants` argument of "
1295          "tf.while_loop to specify a less-specific shape.")
1296
1297
1298def _check_num_inputs_outputs(cond_graph, body_graph, num_flattened_loop_vars):
1299  """Checks the number of inputs/outputs of `cond_graph` and `body_graph`."""
1300  assert len(cond_graph.inputs) == num_flattened_loop_vars, (
1301      "cond_graph takes %d inputs; Expected: %d" % (len(cond_graph.inputs),
1302                                                    num_flattened_loop_vars))
1303  assert len(cond_graph.outputs) == 1, (
1304      "cond_graph has %d outputs; Expected: 1" % len(cond_graph.outputs))
1305  assert len(body_graph.inputs) == num_flattened_loop_vars, (
1306      "body_graph takes %d inputs; Expected: %d" % (len(body_graph.inputs),
1307                                                    num_flattened_loop_vars))
1308  assert len(body_graph.outputs) == num_flattened_loop_vars, (
1309      "body_graph has %d outputs; Expected: %d" % (len(body_graph.outputs),
1310                                                   num_flattened_loop_vars))
1311
1312
1313def _check_inputs_outputs_types_match(body_graph, flattened_loop_vars):
1314  for inp, out, loop_var in zip(body_graph.inputs, body_graph.outputs,
1315                                flattened_loop_vars):
1316    if inp.dtype != out.dtype:
1317      raise TypeError(
1318          f"Loop var {loop_var.name} enters the loop with type {inp.dtype} "
1319          f"but has type {out.dtype} after 1 iteration. {loop_var.name} type "
1320          "should remain constant.")
1321
1322
1323def _build_cond_placeholders_name_prefix(cond_graph):
1324  return cond_graph.unique_name(cond_graph.name + "___redundant_placeholder")
1325
1326
1327def _duplicate_body_captures_in_cond(cond_graph, body_graph_captures):
1328  """Creates placeholders for body captures in cond_graph.
1329
1330  This is needed to match the signatures of the cond and body graphs.
1331
1332  Args:
1333    cond_graph: cond branch graph
1334    body_graph_captures: Tensors which were captured when building the
1335      `body_graph`.
1336  """
1337  types = [t.dtype.as_datatype_enum for t in body_graph_captures]
1338  # TODO(srbs): Providing a unique prefix does not ensure that there is no
1339  # conflict between the placeholder names and existing nodes in the graph.
1340  # However, passing a list of strings may not be performant.
1341  # Ideally, we should move `Graph.unique_name` to C++ or make
1342  # `Graph._names_in_use` a trie so that we can find a unique prefix.
1343  # TODO(b/143286622): This should not be required once captures are separated
1344  # from regular loop vars.
1345  with cond_graph._c_graph.get() as c_graph:
1346    placeholders = c_api.TF_CreatePlaceholders(
1347        c_graph, types,
1348        compat.as_str(_build_cond_placeholders_name_prefix(cond_graph)))
1349  placeholder_ops = [
1350      _OperationWithOutputs(ph.oper, cond_graph)
1351      for ph in placeholders
1352  ]
1353
1354  tensors = []
1355  for op, ph, dtype in zip(placeholder_ops, placeholders, types):
1356    tensor = ops.Tensor._create_with_tf_output(op, 0, dtype, ph)
1357    op._outputs = [tensor]
1358    tensors.append(tensor)
1359
1360  # Update `cond_graph._captures` and `cond_graph.inputs` to contain the
1361  # newly created placeholders.
1362  tuples = zip(body_graph_captures, tensors)
1363  keys = [id(t) for t in body_graph_captures]
1364  cond_graph._captures.update(zip(keys, tuples))
1365  cond_graph.inputs.extend(tensors)
1366
1367
1368def _copy_handle_data(src_tensors, tgt_tensors):
1369  for src_t, tgt_t in zip(src_tensors, tgt_tensors):
1370    handle_data_util.copy_handle_data(src_t, tgt_t)
1371
1372
1373def _graph_name(graph):
1374  if isinstance(graph, func_graph_module.FuncGraph):
1375    return graph.name
1376  return "Base"
1377
1378
1379def _pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, loop_vars):
1380  """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""
1381
1382  def flow_to_tensor_array(flow, ta):  # pylint: disable=missing-docstring
1383    return (tensor_array_ops.build_ta_with_new_flow(ta, flow) if isinstance(  # pylint: disable=g-long-ternary
1384        ta, tensor_array_ops.TensorArray) else flow)
1385
1386  flattened_loop_vars = [
1387      flow_to_tensor_array(*z)
1388      for z in zip(nest.flatten(loop_vars, expand_composites=True),
1389                   flat_orig_loop_vars)
1390  ]
1391  return nest.pack_sequence_as(loop_vars_signature, flattened_loop_vars,
1392                               expand_composites=True)
1393
1394
1395def _tensor_array_to_flow(loop_vars):
1396
1397  def f(maybe_ta):
1398    if isinstance(maybe_ta, tensor_array_ops.TensorArray):
1399      return maybe_ta.flow
1400    return maybe_ta
1401
1402  return nest.map_structure(f, loop_vars, expand_composites=True)
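# Together with `_pack_sequence_as` above, this implements a round trip:
# TensorArray loop vars are lowered to their scalar flow tensors before the
# While op is built (only flows cross the op boundary), and rebuilt into
# TensorArray objects from the new flows afterwards. Sketch (illustrative):
#   flows = _tensor_array_to_flow(loop_vars)
#   ... run the While op over `flows` ...
#   outputs = _pack_sequence_as(signature, flat_orig_loop_vars, flows)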
1403
1404
1405def _build_maximum_iterations_loop_var(maximum_iterations):
1406  if maximum_iterations is None:
1407    # Default value of `max_num_elements` for EmptyTensorList, meaning that
1408    # the list size is unbounded.
1409    maximum_iterations = -1
1410  # EmptyTensorList expects `max_num_elements` to be of type int32.
1411  return ops.convert_to_tensor(
1412      maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
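# E.g. (illustrative): _build_maximum_iterations_loop_var(None) returns an
# int32 tensor holding -1, which EmptyTensorList treats as "no size bound",
# while an explicit limit such as 10 is simply converted to an int32 tensor.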
1413
1414
1415def _build_accumulator_name(tensor):
1416  # Tensor name may be of the form "pow/y:0". Name scope does not allow ":".
1417  return "{}/accumulator".format(tensor.name).replace(":", "_")
1418
1419
1420def _is_loop_invariant(tensor, inputs, outputs):
1421  return (any(tensor is t for t in inputs) and
1422          any(tensor is t for t in outputs))
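# A loop var is invariant when the body returns its input placeholder
# unchanged, so the same Tensor object appears in both `inputs` and `outputs`
# (identity is checked with `is`, not by tensor name or value).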
1423
1424
1425class _OperationWithOutputs(ops.Operation):
1426  """Operation with pre-built `TF_Output`s.
1427
1428  The C API for creating the extra placeholders for the cond graph returns
1429  SWIG-wrapped TF_Output* pointers, which we can use directly for
1430  `Operation.outputs`. The default constructor for `Operation` does not provide
1431  a way of specifying pre-built output tensors and always creates them, which
1432  is a performance overhead. It is not clear whether adding that feature to the
1433  `Operation` API would be generally useful, so for now we just have our own
1434  lightweight `Operation` implementation. Note that this also does not extract
1435  a stacktrace, since we don't expect this operation to be used.
1436
1437  TODO(b/143286622): This should not be required once captures are separated
1438  from regular loop vars.
1439  """
1440
1441  def __init__(self, c_op, g):
1442    self._c_op = c_op
1443    self._graph = g
1444    self._outputs = None  # Initialized by _duplicate_body_captures_in_cond().
1445    self._id_value = g._add_op(self, self.name)
1446    self._is_stateful = False
1447
1448
1449def _set_read_only_resource_inputs_attr(op, branch_graphs):
1450  """Sets the list of resource inputs which are read-only.
1451
1452  This is used by AutomaticControlDependencies.
1453
1454  Args:
1455    op: While Operation.
1456    branch_graphs: List of branch FuncGraphs.
1457  """
1458  read_only_indices = set(range(len(op.inputs)))
1459  for branch_graph in branch_graphs:
1460    if not read_only_indices:
1461      break
1462    branch_read_only_indices = acd.get_read_only_resource_input_indices_graph(
1463        branch_graph)
1464    read_only_indices = read_only_indices.intersection(branch_read_only_indices)
1465
1466  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
1467                        sorted(read_only_indices))
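# Illustrative: if the cond graph only reads a resource input while the body
# graph assigns to it, that input index drops out of the intersection and the
# input is not marked read-only, so AutomaticControlDependencies treats the
# While op as a potential writer of that resource.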
1468
1469# pylint: enable=protected-access
1470