# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Methods for rewriting while_v2 grad functions with IndexedSlices output."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.util import nest


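# For orientation, the mismatch handled below typically arises when a gather
# is differentiated through a while loop. A minimal, illustrative sketch using
# only the public TF API (not part of this module; values are arbitrary):
#
#   v = tf.Variable(tf.ones([10, 3]))
#   with tf.GradientTape() as tape:
#     _, acc = tf.while_loop(
#         lambda i, acc: i < 3,
#         lambda i, acc: (i + 1, acc + tf.reduce_sum(tf.gather(v, [0, 2]))),
#         [tf.constant(0), tf.constant(0.0)])
#   grad = tape.gradient(acc, v)
#
# The gradient of the gather is an IndexedSlices, so the generated gradient
# body produces three Tensors (values, indices, dense_shape) where the forward
# body had a single input, which is the mismatch rewritten by this module.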
def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
                                forward_inputs):
  """Handles the special case of IndexedSlices returned from a while gradient.

  Some gradient functions return IndexedSlices instead of a Tensor (e.g. the
  gradient of Gather ops). When this happens in the gradient of a while body,
  the resulting gradient body function has mismatched inputs and outputs, since
  the input is a single Tensor, but the IndexedSlices is flattened into three
  output Tensors (values, indices, dense_shape).

  This function fixes the mismatch by rewriting the gradient body to have three
  inputs to match the three outputs, i.e., it effectively converts the input
  Tensor into an input IndexedSlices. It also returns new `loop_vars` to
  reflect the new inputs.

  Args:
    grads: the input gradient Tensors to the while gradient computation.
    body_grad_graph: _WhileBodyGradFuncGraph.
    loop_vars: list of Tensors. The inputs to body_grad_graph.
    forward_inputs: list of Tensors. The (flat) inputs to the forward-pass While
      op.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Match up body_grad_graph.structured_outputs with the corresponding
  # forward_inputs.
  #
  # Note that we don't expect a gradient computation to have structured output
  # (e.g. no nested lists), so no need to flatten
  # body_grad_graph.structured_outputs. However, structured_outputs may still
  # contain composite tensors such as IndexedSlices, unlike
  # body_grad_graph.outputs, which contains flattened composite tensors.
  inputs_with_grads = [
      t for g, t in zip(grads, forward_inputs) if g is not None
  ]
  # Skip loop counter, maximum_iterations and total number of loop iterations.
  structured_outputs = body_grad_graph.structured_outputs[3:]

  for forward_input, output in zip(inputs_with_grads, structured_outputs):
    if not isinstance(output, indexed_slices.IndexedSlices):
      continue

    if forward_input.dtype == dtypes.resource:
      # TODO(skyewm): In theory we should use this for all captured inputs, not
      # just resource handles (which can only be captured). We can do this by
      # checking that forward_input is passed straight through to its output.
      loop_vars = _rewrite_input_as_indexed_slices(body_grad_graph, output,
                                                   forward_input, loop_vars)
    else:
      _rewrite_output_as_tensor(body_grad_graph, output)

  return loop_vars


def _get_tensor_index_in_iterable(iterable, t):
  """Returns index of first occurrence of `t`; raises ValueError if not found."""
  for i, elem in enumerate(iterable):
    if t is elem:
      return i
  raise ValueError(f"Element `{t!r}` is not found in iterable `{iterable!r}`.")


def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
  """Rewrites grad_output_slices to be a Tensor output.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
  """
  with body_grad_graph.as_default():
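    # Converting the IndexedSlices to a Tensor materializes the dense
    # gradient, which can consume a lot of memory for large dense shapes, but
    # keeps the rewritten graph's input and output as matching plain Tensors.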
    new_output = ops.convert_to_tensor_v2(grad_output_slices)

  idx = _get_tensor_index_in_iterable(body_grad_graph.structured_outputs,
                                      grad_output_slices)
  body_grad_graph.structured_outputs[idx] = new_output
  body_grad_graph.outputs = func_graph.flatten(
      body_grad_graph.structured_outputs)


def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices,
                                     forward_input, loop_vars):
  """Rewrites grad_output_slices's corresponding input to be an IndexedSlices.

  This rewrite requires that forward_input was captured in the forward loop,
  i.e. is not a user-specified loop variable. This is important because the
  rewrite assumes that forward_input is passed through to its corresponding
  output unchanged. This assumption is used in
  _rewrite_grad_indexed_slices_output, which depends on the exact gradient
  structure produced by the input's fanout.

  This can yield a more efficient computation than using
  _rewrite_output_as_tensor, since it preserves the IndexedSlices structure
  instead of converting the IndexedSlices to a dense Tensor.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
    forward_input: the corresponding Tensor input to the forward loop.
    loop_vars: list of Tensors. The inputs to body_grad_graph.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Create the initial IndexedSlices that will be the input to the grad While
  # op. It starts as zeros and accumulates the IndexedSlices grad output. Note
  # that because forward_input is captured and not a loop var, its incoming
  # gradient should always be zero.
  init_slices = _create_grad_indexed_slices_init(grad_output_slices,
                                                 forward_input)

  # Create a new version of grad_output_slices's gradient computation that uses
  # the new IndexedSlices input instead of the original Tensor input. We'll
  # return the new computation and leave the old computation as dead code.
  # TODO(skyewm): consider pruning body_grad_graph to remove the old
  # computation.
  with body_grad_graph.as_default():
    input_slices = indexed_slices.IndexedSlices(
        values=body_grad_graph.capture(init_slices.values, allowlisted=True),
        indices=body_grad_graph.capture(init_slices.indices, allowlisted=True),
        dense_shape=body_grad_graph.capture(
            init_slices.dense_shape, allowlisted=True))

    # Remove the captured tensors from the function inputs. We'll add them back
    # at the correct index in _update_indexed_slices_param.
    for t in _flatten(init_slices):
      captured_t = body_grad_graph.captures.pop(t)
      body_grad_graph.inputs.remove(captured_t)

    new_output_slices = _rewrite_grad_indexed_slices_output(
        grad_output_slices, input_slices)

  # Update body_grad_graph's inputs and outputs to reflect the new
  # IndexedSlices computation.
  return _update_indexed_slices_param(body_grad_graph, loop_vars, init_slices,
                                      input_slices, new_output_slices,
                                      grad_output_slices)


def _create_grad_indexed_slices_init(grad_output_slices, forward_input):
  """Creates an IndexedSlices to pass as input to the while grad function.

  Args:
    grad_output_slices: IndexedSlices. The corresponding while grad function
      output.
    forward_input: Tensor. The corresponding input to the forward while op.

  Returns:
    A zeros IndexedSlices, created in the current Graph.
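
  For example (illustrative shapes): if forward_input is a float32 Tensor of
  shape [10, 3] and the output values have a fully defined shape, the result
  is equivalent to
  IndexedSlices(values=zeros([0, 3]), indices=[], dense_shape=[10, 3]).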
176  """
177  assert isinstance(grad_output_slices, indexed_slices.IndexedSlices)
178  assert isinstance(forward_input, ops.Tensor)
179  values_out = grad_output_slices.values
180  indices_out = grad_output_slices.indices
181
182  # Create the initial values tensor.
183  if values_out.shape.is_fully_defined():
184    values_shape = tensor_shape.TensorShape([0] +
185                                            values_out.shape.as_list()[1:])
186    values = array_ops.zeros(
187        values_shape, dtype=values_out.dtype, name="values_init")
188  else:
189    if forward_input.dtype == dtypes.resource:
190      forward_shape = gen_resource_variable_ops.variable_shape(forward_input)
191    else:
192      forward_shape = array_ops.shape(forward_input)
193    values_shape = array_ops.concat([[0], forward_shape[1:]], 0)
194    values = array_ops.zeros(
195        values_shape, dtype=values_out.dtype, name="values_init")
196
197  # Create the initial indices tensor.
198  indices = constant_op.constant([], indices_out.dtype, name="indices_init")
199
200  # Create the initial dense_shape tensor. We assume is the same shape as
201  # forward_input, since captured tensors don't change shape across loop
202  # iterations.
203  if forward_input.dtype == dtypes.resource:
204    shape = gen_resource_variable_ops.variable_shape(
205        forward_input, name="shape_init")
206  else:
207    shape = array_ops.shape(forward_input, name="shape_init")
208
209  return indexed_slices.IndexedSlices(
210      values=values, indices=indices, dense_shape=shape)


def _rewrite_grad_indexed_slices_output(old_output_slices, new_input_slices):
  """Creates a new version of old_output_slices with new_input_slices as input.

  This method assumes that old_output_slices.{values,indices} are produced by
  concatenating the incoming gradient Tensor input with the IndexedSlices
  produced by the gradient computation of the while body. See
  backprop.aggregate_indexed_slices_gradients for where these concats are
  constructed. We build new concats that use new_input_slices instead of the
  original Tensor input.

  Args:
    old_output_slices: original IndexedSlices output of while gradient.
    new_input_slices: new IndexedSlices to use as input to while gradient.

  Returns:
    A new IndexedSlices to replace old_output_slices.
  """

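  # Expected graph pattern for values/indices (a structural assumption; see
  # backprop.aggregate_indexed_slices_gradients):
  #   old_output = Identity(ConcatV2(incoming_grad, body_grad_parts..., axis))
  # rewrite() below splices new_input in place of incoming_grad.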
  def rewrite(old_output, new_input):
    assert old_output.type == "Identity"
    concat_op = old_output.inputs[0].op
    assert concat_op.type == "ConcatV2"
    # Don't include the axis arg.
    old_concat_args = concat_op.inputs[:-1]
    # We assume that the original gradient input was the first argument to the
    # concat op.
    # TODO(skyewm): do this in a more robust way.
    return array_ops.concat([new_input] + old_concat_args[1:], 0)

  values = rewrite(old_output_slices.values.op, new_input_slices.values)
  indices = rewrite(old_output_slices.indices.op, new_input_slices.indices)
  return indexed_slices.IndexedSlices(
      values=values, indices=indices, dense_shape=new_input_slices.dense_shape)


def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
                                 output_slices, old_output_slices):
  """Updates graph with new IndexedSlices input/output.

  Updates graph's metadata to output the gradient computation defined by
  init_slices, input_slices, and output_slices, instead of outputting
  old_output_slices. Also returns a new version of loop_vars with init_slices
  replacing the old input.

  Args:
    graph: _WhileBodyGradFuncGraph.
    loop_vars: the inputs to graph.
    init_slices: the new IndexedSlices to use as input to graph.
    input_slices: the new IndexedSlices in graph that should be fed by
      init_slices.
    output_slices: the new IndexedSlices in graph that should be the
      corresponding output to input_slices.
    old_output_slices: the IndexedSlices in graph that are currently being
      output.

  Returns:
    New loop_vars to pass to graph.
  """
  structured_idx = _get_tensor_index_in_iterable(graph.structured_outputs,
                                                 old_output_slices)
  # We assume that the component tensors of old_output_slices appear
  # sequentially in graph.outputs. We use the first of these tensors
  # as the reference index.
  flat_idx = _get_tensor_index_in_iterable(
      graph.outputs,
      func_graph.flatten(old_output_slices)[0])

  graph.structured_outputs[structured_idx] = output_slices
  graph.outputs = func_graph.flatten(graph.structured_outputs)

  graph.inputs = (
      graph.inputs[:flat_idx] + _flatten(input_slices) +
      graph.inputs[flat_idx + 1:])

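  # Splice the three init tensors (values, indices, dense_shape) in place of
  # the single old Tensor input, mirroring the graph.inputs update above.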
  return loop_vars[:flat_idx] + _flatten(init_slices) + loop_vars[flat_idx + 1:]


def _flatten(arg):
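  """Flattens `arg`, expanding composite tensors.

  E.g. an IndexedSlices flattens to its components [values, indices,
  dense_shape], the layout _update_indexed_slices_param relies on.
  """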
  return nest.flatten(arg, expand_composites=True)