# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Methods for rewriting while_v2 grad functions with IndexedSlices output."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.util import nest


def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
                                forward_inputs):
  """Handles special case of IndexedSlices returned from while gradient.

  Some gradient functions return IndexedSlices instead of a Tensor (e.g. the
  gradient of Gather ops). When this happens in the gradient of a while body,
  the resulting gradient body function will have mismatched inputs and
  outputs, since the input is a single Tensor, but the IndexedSlices gets
  unnested into three output Tensors.

  This function fixes the mismatch by rewriting the gradient body to have
  three inputs to match the three outputs, i.e., it effectively converts the
  input Tensor into an input IndexedSlices. It also returns new `loop_vars` to
  reflect the new inputs.

  Args:
    grads: the input gradient Tensors to the while gradient computation.
    body_grad_graph: _WhileBodyGradFuncGraph.
    loop_vars: list of Tensors. The inputs to body_grad_graph.
    forward_inputs: list of Tensors. The (flat) inputs to the forward-pass
      While op.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Match up body_grad_graph.structured_outputs with the corresponding
  # forward_inputs.
  #
  # Note that we don't expect a gradient computation to have structured output
  # (e.g. no nested lists), so no need to flatten
  # body_grad_graph.structured_outputs. However, structured_outputs may still
  # contain composite tensors such as IndexedSlices, unlike
  # body_grad_graph.outputs, which contains flattened composite tensors.
  inputs_with_grads = [
      t for g, t in zip(grads, forward_inputs) if g is not None
  ]
  # Skip loop counter, maximum_iterations and total number of loop iterations.
  structured_outputs = body_grad_graph.structured_outputs[3:]

  for forward_input, output in zip(inputs_with_grads, structured_outputs):
    if not isinstance(output, indexed_slices.IndexedSlices):
      continue

    if forward_input.dtype == dtypes.resource:
      # TODO(skyewm): In theory we should use this for all captured inputs,
      # not just resource handles (which can only be captured). We can do this
      # by checking that forward_input is passed straight through to its
      # output.
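      # Illustrative sketch (comments only; `v` and the tensor names are
      # hypothetical): for a captured resource `v` whose gradient is an
      # IndexedSlices, the rewrite below changes the grad body signature
      # roughly as follows:
      #   before: inputs:  [..., grad_v]                        (1 Tensor)
      #           outputs: [..., values, indices, dense_shape]  (3 Tensors)
      #   after:  inputs:  [..., values, indices, dense_shape]  (3 Tensors)
      # so the gradient While op's inputs and outputs match up again.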
      loop_vars = _rewrite_input_as_indexed_slices(body_grad_graph, output,
                                                   forward_input, loop_vars)
    else:
      _rewrite_output_as_tensor(body_grad_graph, output)

  return loop_vars


def _get_tensor_index_in_iterable(iterable, t):
  """Returns index of first occurrence of `t`, raises ValueError if not found."""
  for i, elem in enumerate(iterable):
    if t is elem:
      return i
  raise ValueError(f"Element `{t!r}` is not found in iterable `{iterable!r}`.")


def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
  """Rewrites grad_output_slices to be a Tensor output.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
  """
  with body_grad_graph.as_default():
    new_output = ops.convert_to_tensor_v2(grad_output_slices)

  idx = _get_tensor_index_in_iterable(body_grad_graph.structured_outputs,
                                      grad_output_slices)
  body_grad_graph.structured_outputs[idx] = new_output
  body_grad_graph.outputs = func_graph.flatten(
      body_grad_graph.structured_outputs)


def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices,
                                     forward_input, loop_vars):
  """Rewrites grad_output_slices's corresponding input to be an IndexedSlices.

  This rewrite requires that forward_input was captured in the forward loop,
  i.e. is not a user-specified loop variable. This is important because the
  rewrite assumes that forward_input is passed through to its corresponding
  output unchanged. This assumption is used in
  _rewrite_grad_indexed_slices_output, which depends on the exact gradient
  structure produced by the input's fanout.

  This can yield a more efficient computation than using
  _rewrite_output_as_tensor, since it preserves the IndexedSlices structure
  instead of converting the IndexedSlices to a dense Tensor.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
    forward_input: the corresponding Tensor input to the forward loop.
    loop_vars: list of Tensors. The inputs to body_grad_graph.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Create the initial IndexedSlices that will be the input to the grad While
  # op. This starts as zeros, and accumulates the IndexedSlices grad output.
  # Note that because forward_input is captured and not a loop var, its
  # incoming gradient should always be zero.
  init_slices = _create_grad_indexed_slices_init(grad_output_slices,
                                                 forward_input)

  # Create a new version of grad_output_slices's gradient computation that
  # uses the new IndexedSlices input instead of the original Tensor input.
  # We'll return the new computation and leave the old computation as dead
  # code.
  # TODO(skyewm): consider pruning body_grad_graph to remove the old
  # computation.
  with body_grad_graph.as_default():
    input_slices = indexed_slices.IndexedSlices(
        values=body_grad_graph.capture(init_slices.values, allowlisted=True),
        indices=body_grad_graph.capture(init_slices.indices,
                                        allowlisted=True),
        dense_shape=body_grad_graph.capture(
            init_slices.dense_shape, allowlisted=True))

    # Remove the captured tensors from the function inputs. We'll add them
    # back at the correct index in _update_indexed_slices_param.
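    # For reference, a minimal sketch of what the loop below walks (assuming
    # the standard expand_composites flattening order for IndexedSlices):
    #   _flatten(init_slices) == [init_slices.values,
    #                             init_slices.indices,
    #                             init_slices.dense_shape]
    # Popping each external tensor from `captures` and removing its
    # placeholder from `inputs` detaches the capture from the function
    # signature entirely.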
    for t in _flatten(init_slices):
      captured_t = body_grad_graph.captures.pop(t)
      body_grad_graph.inputs.remove(captured_t)

    new_output_slices = _rewrite_grad_indexed_slices_output(
        grad_output_slices, input_slices)

  # Update body_grad_graph's inputs and outputs to reflect the new
  # IndexedSlices computation.
  return _update_indexed_slices_param(body_grad_graph, loop_vars, init_slices,
                                      input_slices, new_output_slices,
                                      grad_output_slices)


def _create_grad_indexed_slices_init(grad_output_slices, forward_input):
  """Creates an IndexedSlices to pass as input to the while grad function.

  Args:
    grad_output_slices: IndexedSlices. The corresponding while grad function
      output.
    forward_input: Tensor. The corresponding input to the forward while op.

  Returns:
    Zeros IndexedSlices, created in current Graph.
  """
  assert isinstance(grad_output_slices, indexed_slices.IndexedSlices)
  assert isinstance(forward_input, ops.Tensor)
  values_out = grad_output_slices.values
  indices_out = grad_output_slices.indices

  # Create the initial values tensor.
  if values_out.shape.is_fully_defined():
    values_shape = tensor_shape.TensorShape([0] +
                                            values_out.shape.as_list()[1:])
    values = array_ops.zeros(
        values_shape, dtype=values_out.dtype, name="values_init")
  else:
    if forward_input.dtype == dtypes.resource:
      forward_shape = gen_resource_variable_ops.variable_shape(forward_input)
    else:
      forward_shape = array_ops.shape(forward_input)
    values_shape = array_ops.concat([[0], forward_shape[1:]], 0)
    values = array_ops.zeros(
        values_shape, dtype=values_out.dtype, name="values_init")

  # Create the initial indices tensor.
  indices = constant_op.constant([], indices_out.dtype, name="indices_init")

  # Create the initial dense_shape tensor. We assume it is the same shape as
  # forward_input, since captured tensors don't change shape across loop
  # iterations.
  if forward_input.dtype == dtypes.resource:
    shape = gen_resource_variable_ops.variable_shape(
        forward_input, name="shape_init")
  else:
    shape = array_ops.shape(forward_input, name="shape_init")

  return indexed_slices.IndexedSlices(
      values=values, indices=indices, dense_shape=shape)


def _rewrite_grad_indexed_slices_output(old_output_slices, new_input_slices):
  """Creates a new version of old_output_slices with new_input_slices as input.

  This method assumes that old_output_slices.{values,indices} are produced by
  concatenating the incoming gradient Tensor input with the IndexedSlices
  produced by the gradient computation of the while body. See
  backprop.aggregate_indexed_slices_gradients for where these concats are
  constructed. We build new concats that use new_input_slices instead of the
  original Tensor input.

  Args:
    old_output_slices: original IndexedSlices output of while gradient.
    new_input_slices: new IndexedSlices to use as input to while gradient.

  Returns:
    A new IndexedSlices to replace old_output_slices.
  """

  def rewrite(old_output, new_input):
    assert old_output.type == "Identity"
    concat_op = old_output.inputs[0].op
    assert concat_op.type == "ConcatV2"
    # Don't include the axis arg.
    old_concat_args = concat_op.inputs[:-1]
    # We assume that the original gradient input was the first argument to
    # the concat op.
    # TODO(skyewm): do this in a more robust way.
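    # Rough before/after of this surgery (tensor names are illustrative):
    #   before: Identity(ConcatV2(old_dense_grad, body_grad, axis))
    #   after:  ConcatV2(new_input, body_grad, axis=0)
    # i.e. we rebuild the concat with the new IndexedSlices component in
    # place of the original dense gradient argument, bypassing the Identity.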
    return array_ops.concat([new_input] + old_concat_args[1:], 0)

  values = rewrite(old_output_slices.values.op, new_input_slices.values)
  indices = rewrite(old_output_slices.indices.op, new_input_slices.indices)
  return indexed_slices.IndexedSlices(
      values=values, indices=indices, dense_shape=new_input_slices.dense_shape)


def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
                                 output_slices, old_output_slices):
  """Updates graph with new IndexedSlices input/output.

  Updates graph's metadata to output the gradient computation defined by
  init_slices, input_slices, and output_slices, instead of outputting
  old_output_slices. Also returns a new version of loop_vars with init_slices
  replacing the old input.

  Args:
    graph: _WhileBodyGradFuncGraph.
    loop_vars: the inputs to graph.
    init_slices: the new IndexedSlices to use as input to graph.
    input_slices: the new IndexedSlices in graph that should be fed by
      init_slices.
    output_slices: the new IndexedSlices in graph that should be the
      corresponding output to input_slices.
    old_output_slices: the IndexedSlices in graph that are currently being
      output.

  Returns:
    New loop_vars to pass to graph.
  """
  structured_idx = _get_tensor_index_in_iterable(graph.structured_outputs,
                                                 old_output_slices)
  # We assume that the component tensors of old_output_slices appear
  # sequentially in graph.outputs. We use the first of these tensors as the
  # reference index.
  flat_idx = _get_tensor_index_in_iterable(
      graph.outputs,
      func_graph.flatten(old_output_slices)[0])

  graph.structured_outputs[structured_idx] = output_slices
  graph.outputs = func_graph.flatten(graph.structured_outputs)

  graph.inputs = (
      graph.inputs[:flat_idx] + _flatten(input_slices) +
      graph.inputs[flat_idx + 1:])

  return loop_vars[:flat_idx] + _flatten(init_slices) + loop_vars[flat_idx + 1:]


def _flatten(arg):
  return nest.flatten(arg, expand_composites=True)
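
# Example wiring (a minimal sketch in comments, assuming a while_v2-style
# caller; `grads`, `body_grad_graph`, `loop_vars` and `while_op` are
# illustrative names, not part of this module):
#
#   loop_vars = rewrite_grad_indexed_slices(
#       grads, body_grad_graph, loop_vars, while_op.inputs)
#
# Afterwards, each captured resource input whose gradient is an IndexedSlices
# is represented by three consecutive entries (values, indices, dense_shape)
# in `loop_vars`, matching the rewritten grad body's inputs and outputs.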