# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""while_v2 and gradient.

This is a version of while_loop that emits a single While op, as well as the
gradient function for While ops produced by while_loop. This will eventually
replace the current tf.while_loop implementation once it reaches feature and
performance parity.
"""
import collections

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util as util_v1
from tensorflow.python.ops import control_flow_util_v2 as util
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import while_v2_indexed_slices_rewriter
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import variable_utils

# pylint: disable=protected-access

# Controls parallelism in the presence of side-effecting ops like variable
# operations, print, py_function, etc. Can be set to True, False, or
# "stateless_cond" (default).
# Note that loops without side-effecting operations always execute with maximum
# parallelism, ignoring this setting. When False, loops with side-effecting ops
# execute sequentially, one iteration at a time.
# When True, loops with side-effecting ops may execute parts of different
# iterations in parallel; caution: if the loop condition contains
# side-effecting ops, this mode produces unspecified results.
# Setting it to "stateless_cond" automatically sets this mode to True when
# the loop condition is free of side-effecting ops.
# TODO(b/152548567): Change this to "stateless_cond".
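#
# Illustrative sketch (an assumption about usage, not a public API): code that
# imports this private module directly could opt in with, e.g.,
#
#   from tensorflow.python.ops import while_v2
#   while_v2.glob_stateful_parallelism = "stateless_cond"
#
# before tracing its loops, enabling parallel iterations whenever the loop
# condition has no side-effecting ops.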
glob_stateful_parallelism = False


def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True,
               back_prop=True):
  """Like tf.while_loop, except emits a single While op."""
  loop_vars = variable_utils.convert_variables_to_tensors(loop_vars)
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  flat_orig_loop_vars = nest.flatten(orig_loop_vars, expand_composites=True)
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = _tensor_array_to_flow(loop_vars)
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
      expand_composites=True)

  # `loop_vars_signature` is a structure of TypeSpecs and has the same
  # structure as the `orig_loop_vars`. If `shape_invariants` is not None, its
  # shape information comes from `shape_invariants` instead of
  # `orig_loop_vars`. It is used to pack flattened vars into structured vars.
  if shape_invariants is not None:
    loop_vars_signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec,
        loop_vars, shape_invariants)
  else:
    loop_vars_signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec, loop_vars)

  flat_shape_invariants = nest.map_structure(
      lambda spec: spec.shape,
      nest.flatten(loop_vars_signature, expand_composites=True))

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + list(loop_vars)

    func_graph_signature = (
        [tensor_spec.TensorSpec.from_tensor(loop_counter),
         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
        list(loop_vars_signature))

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens `args`,
      # converts flows in `args` to TensorArrays and packs it into the
      # structure of `loop_vars_signature`.
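      # For example (illustrative): a TensorArray loop var travels through the
      # While op as its `flow` tensor and is rebuilt into a TensorArray here,
      # before the user-provided `cond` ever sees it.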
      pred = cond(
          *_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args))
      if (tensor_util.is_tf_type(pred) and
          (pred.shape.dims is None or pred.shape.dims)):
        pred = array_ops.squeeze_v2(pred)

      if maximum_iterations is None:
        return pred
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, pred)

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=func_graph_signature,
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    if glob_stateful_parallelism == "stateless_cond":
      stateful_parallelism = (not any(
          op._is_stateful for op in cond_graph.get_operations()))
    else:
      stateful_parallelism = glob_stateful_parallelism

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # The function was created with a signature rather than tensors, so
      # internal placeholders were created without handle data.
      _copy_handle_data(nest.flatten(loop_vars[2:], expand_composites=True),
                        nest.flatten(args, expand_composites=True))
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens `args`,
      # converts flows in `args` to TensorArrays and packs it into the
      # structure of `loop_vars_signature`.
      outputs = body(
          *_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args))
      if not nest.is_nested(outputs):
        outputs = [outputs]
      try:
        # The legacy while_loop considers list and tuple to be the same
        # structure.
        nest.assert_same_structure(outputs, orig_loop_vars, check_types=False,
                                   expand_composites=True)
      except ValueError:
        # Traditionally we consider variables and tensors to be the same
        # structure.
        vars1 = variable_utils.convert_variables_to_tensors(outputs)
        vars2 = variable_utils.convert_variables_to_tensors(orig_loop_vars)
        nest.assert_same_structure(vars1, vars2, check_types=False,
                                   expand_composites=True)
      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=func_graph_signature,
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies,
        acd_record_initial_resource_uses=stateful_parallelism)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    deferred_external_captures = nest.flatten(
        [c() for c in body_graph.deferred_external_captures],
        expand_composites=True)
    loop_vars = (
        loop_vars + body_graph.external_captures + deferred_external_captures)
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)
    body_graph.outputs.extend(body_graph.deferred_internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      _duplicate_body_captures_in_cond(
          cond_graph, body_graph.external_captures[num_cond_captures:] +
          deferred_external_captures)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
                                             expand_composites=True))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        flat_shape_invariants,
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars], expand_composites=True))

    num_original_outputs = len(body_graph.outputs)
    if back_prop and util.output_all_intermediates():
      # Export all tensors in the loop body that may be needed for gradient
      # computation. We do this by accumulating the intermediate values in
      # TensorLists.
      intermediate_tensors = _get_intermediates(body_graph)

      for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)
        loop_vars.append(tensor_list)
        with cond_graph.as_default():
          # Add a placeholder to cond_graph's inputs corresponding to the
          # tensor_list.
          cond_graph.capture(tensor_list)
        with body_graph.as_default():
          # Push the intermediate tensor to the tensor list. This captures the
          # `tensor_list` as well.
          appended_tensor_list = list_ops.tensor_list_push_back(
              tensor_list, intermediate_tensor)
          # Add this modified tensor list to the list of outputs.
          body_graph.outputs.append(appended_tensor_list)

    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))
    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

    with ops.control_dependencies(
        list(cond_graph.control_captures) + list(body_graph.control_captures)):
      output_shapes = [t.shape for t in body_graph.outputs]
      orig_loop_vars_range = slice(first_loop_var_index,
                                   first_loop_var_index + num_flattened_outputs)
      output_shapes[orig_loop_vars_range] = flat_shape_invariants

      outputs = _build_while_op(
          flattened_loop_vars,
          cond_graph,
          body_graph,
          output_shapes=output_shapes,
          parallel_iterations=parallel_iterations,
          name=scope,
          num_original_outputs=num_original_outputs,
          stateful_parallelism=stateful_parallelism)
    if not ops.get_default_graph().building_function:
      # In V1 graph mode, return identities for each output of the While op,
      # rather than the output of the While op directly. This makes pruning work
      # if the output of while_loop() is fetched: the lowering pass converts the
      # While outputs into IdentityN outputs, which if fetched will cause all
      # ops in the body to be run (since it takes all exit ops as input). After
      # lowering, each output identity op will end up with only the appropriate
      # exit op as input.
      outputs = tuple(array_ops.identity(t) for t in outputs)

  output_loop_vars = outputs[first_loop_var_index:first_loop_var_index +
                             num_flattened_outputs]
  if not back_prop:
    output_loop_vars = [array_ops.stop_gradient(t) for t in output_loop_vars]
  outputs = _pack_sequence_as(
      loop_vars_signature, flat_orig_loop_vars, output_loop_vars)

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs, expand_composites=True)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs


@ops.RegisterGradient("StatelessWhile")
@ops.RegisterGradient("While")
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond", "_cond_graph")
  body_graph = _get_graph(while_op, "body", "_body_graph")
  orig_num_params = len(body_graph.outputs)

  maximum_iterations = op.inputs[1]
  parallel_iterations = op.get_attr("parallel_iterations")

  try:
    num_original_outputs = while_op.get_attr("_num_original_outputs")
  except:  # pylint: disable=bare-except
    num_original_outputs = len(while_op.outputs)

  try:
    stateful_parallelism = while_op.get_attr("_stateful_parallelism")
  except:  # pylint: disable=bare-except
    stateful_parallelism = False

  num_intermediates = len(while_op.outputs) - num_original_outputs
  grads = [
      _preprocess_grad(grad, body_out, while_in, while_out)  # pylint: disable=g-complex-comprehension
      for grad, body_out, while_in, while_out in zip(
          grads[:num_original_outputs],
          body_graph.outputs[:num_original_outputs],
          while_op.inputs[:num_original_outputs],
          while_op.outputs[:num_original_outputs])
  ] + [None] * num_intermediates

  # Skip gradients with respect to the captures whenever possible.
  if "skip_input_indices" in op.__dict__ and op.skip_input_indices is not None:
    captures_start_index = (
        len(body_graph.inputs) - len(body_graph.internal_captures))
    for i in op.skip_input_indices:
      if i >= captures_start_index:
        grads[i] = None

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations,
      stateful_parallelism)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    # `body_grad_graph.extra_inputs` here is equivalent to skimming off the new
    # `body_graph.external_captures` added during `_create_grad_func`.
    new_inputs = body_grad_graph.extra_inputs
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    if len(body_graph.output_types) != len(while_op.inputs) + len(new_inputs):
      # Continuing leads to an invalid graph with disconnected inputs.
      raise AssertionError(
          "Inputs and outputs constructed for the forward op of a While "
          "gradient don't match: 'output_types' has length "
          f"{len(body_graph.output_types)}, 'inputs' has length "
          f"{len(while_op.inputs)}, and 'new_inputs' has length "
          f"{len(new_inputs)}. This doesn't make sense, please file a bug.")
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    _copy_handle_data(new_outputs, while_op.outputs[orig_num_params:])

    # Do not ignore grads wrt extra outputs when computing higher order
    # derivatives.
    while_op._set_attr("_num_original_outputs",
                       attr_value_pb2.AttrValue(i=len(while_op.outputs)))

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
                *unused_args):
    return counter < forward_loop_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = _build_while_op(
      loop_vars,
      cond_grad_graph,
      body_grad_graph,
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name,
      num_original_outputs=len(body_grad_graph.outputs),
      stateful_parallelism=stateful_parallelism)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)


def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
                    parallel_iterations, name, num_original_outputs,
                    stateful_parallelism):
  """Builds the functional StatelessWhile/While op."""
  cond_stateful_ops = [
      op for op in cond_graph.get_operations() if op._is_stateful
  ]
  body_stateful_ops = [
      op for op in body_graph.get_operations() if op._is_stateful
  ]
  if (cond_stateful_ops or body_stateful_ops):
    op_fn = gen_functional_ops._while
  else:
    op_fn = gen_functional_ops.stateless_while

  def _make_op(inputs):
    while_op, tensors = util.get_op_and_outputs(op_fn(
        inputs,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=output_shapes,
        parallel_iterations=parallel_iterations,
        name=name))
    _copy_handle_data(body_graph.outputs, tensors)
    util.maybe_set_lowering_attr(while_op)
    util.maybe_propagate_compile_time_consts_in_xla(while_op)
    _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
    # This is needed so we do not compute derivative wrt these extra outputs.
    while_op._set_attr("_num_original_outputs",
                       attr_value_pb2.AttrValue(i=num_original_outputs))
    while_op._set_attr("_stateful_parallelism",
                       attr_value_pb2.AttrValue(b=stateful_parallelism))
    # The while op may be created inside a tf.function, in which case ops
    # need to capture "through" it when taking gradients; outer_graph is used
    # as a sanity check that capturing only happens from parent to child.
    cond_graph.outer_graph = ops.get_default_graph()
    body_graph.outer_graph = ops.get_default_graph()
    while_op._cond_graph = cond_graph
    while_op._body_graph = body_graph
    return tensors
  return util.run_as_function_for_tape_gradients(_make_op, loop_vars)


def _get_intermediates(func_graph):
  """Returns all tensors in `func_graph` that should be accumulated."""
  # We currently accumulate output tensors of most ops in the function and rely
  # on the pruning pass to get rid of the unused accumulators at runtime.
  # However, this can bloat the GraphDef and make debugging harder so we perform
  # some optimizations.
  #
  # Optimizations we currently perform:
  # 1. We do not accumulate tensors which already have an accumulator
  #    in the loop body.
  # 2. We do not accumulate outputs of Identity nodes. When building the
  #    FuncGraph, we add an Identity node for each output (see
  #    `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
  #    of all these nodes bloats the GraphDef quite a bit so we remove those.
  #    Since the gradient of an Identity node does not rely on its forward op's
  #    input this is safe to do.
  #
  # Other possible optimizations:
  # 1. Only accumulate tensors that will be required by the backward pass.
  #    This will require running the gradient pass and hence would increase the
  #    graph building time for the forward pass.
  # 2. Do not accumulate Const nodes created inside the loop body.
  # 3. Do not accumulate loop vars that are returned as-is just like captured
  #    tensors.
  intermediates = []
  reverse_captures = dict((v.ref(), k) for k, v in func_graph.captures)

  for op in func_graph.get_operations():
    if op.type == "Identity":
      continue
    # Accumulating mutexes can cause deadlock.
    if op.type == "MutexLock":
      continue
    for o in op.outputs:
      if (o is not func_graph.inputs[0] and  # Loop counter.
          o.dtype != dtypes.resource and  # Do not accumulate resource tensors.
          _get_accumulator(o) is None and  # Has existing accumulator.
          o.ref() not in reverse_captures
         ):  # Captured value, hence loop invariant.
        intermediates.append(o)
  return intermediates


def _preprocess_grad(grad, body_graph_output, while_op_input, while_op_output):
  """Returns the initial gradient to be used for a given output tensor.

  Args:
    grad: the original gradient Tensor passed to the gradient function.
    body_graph_output: the corresponding Tensor in the body graph.
    while_op_input: the corresponding Tensor input of the While op.
    while_op_output: the corresponding Tensor output of the While op.

  Returns:
    A Tensor or None.
  """
  # Set the incoming gradient of non-trainable inputs to None. It is possible
  # that we receive non-None gradients for non-trainable types in nested while
  # loops because we accumulate outputs of the inner while as variant tensors
  # which are trainable and hence receive zeros_like tensors in the gradient
  # pass. The non-trainable tensors then receive the popped zeros tensor from
  # this zeros variant. The gradient for the loop vars corresponding to these
  # tensors is None or zeros (this happens only if the loop var is accumulated
  # as well) in _grad_fn so we reset these.
  # TODO(b/118712257): Remove once we can handle None output grads in _grad_fn.
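  # For example (illustrative): an int32 loop var is not trainable, so even if
  # a nested loop's variant accumulator produced a zeros gradient for it, that
  # gradient is dropped here by returning None.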
  if not _is_trainable(body_graph_output):
    return None

  # GradientTape initializes resource and variant grads as None instead of
  # zeros. Set to zeros so _GradientsHelper computes the gradients instead of
  # returning None.
  # TODO(b/143286622): The supports_default_grad check is needed
  # because While op emits non-differentiable resource tensors
  # as outputs. Remove this check when that is not the case.
  # Note: We use `while_op_input` instead of `while_op_output` for the call
  # to `supports_default_grad` because `while_op_output` may be missing
  # handle_data if the While is in a restored saved model.
  if (while_op_output.dtype in (dtypes.resource, dtypes.variant) and
      default_gradient.supports_default_grad(while_op_input) and grad is None):
    return _zeros_like(while_op_input, while_op_output)

  # Convert IndexedSlices to dense tensors since it is unlikely that downstream
  # gradient functions will properly handle indexed slices. This is similar to
  # what we do in tf.function gradients.
  if isinstance(grad, indexed_slices.IndexedSlices):
    return ops.convert_to_tensor(grad)

  return grad


# TODO(skyewm): make this return constants if op_output's shape is fully
# defined (this can be done by checking the "shape" attr of resource vars).
def _zeros_like(op_input, op_output):
  """Like array_ops.zeros_like() but also accepts resource var handles."""
  if op_output.dtype == dtypes.resource:
    # Note: We use `op_input` instead of `op_output` to get the zeros dtype
    # because `op_output` may be missing handle_data if the While is in a
    # restored saved model.
    return array_ops.zeros(
        gen_resource_variable_ops.variable_shape(op_output),
        dtype=default_gradient.get_zeros_dtype(op_input))
  return array_ops.zeros_like(op_output)


def _is_trainable(tensor):
  """Returns whether the given tensor is trainable."""
  if not backprop_util.IsTrainable(tensor):
    return False

  # Special case: untrainable accumulator output. The gradients algorithm
  # doesn't know about tensor lists of untrainable elements. In theory the
  # tensor list gradient functions should return None as appropriate, but
  # because we can't return None from the gradient function we filter out
  # untrainable accumulator output here to avoid computing the gradient at all.
  if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
    assert tensor.dtype == dtypes.variant
    element_type = tensor.op.get_attr("element_dtype")
    return backprop_util.IsTrainable(element_type)

  return True


def _get_graph(while_op, func_attr_name, attr_graph_name):
  """Returns `FuncGraph` for the given function attribute.

  Args:
    while_op: The While Operation.
    func_attr_name: string
    attr_graph_name: cached forward graph name

  Returns:
    `FuncGraph`
  """
  func_graph = getattr(while_op, attr_graph_name, None)
  if func_graph is None:
    # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
    input_shapes = [
        tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
    ]
    func_name = while_op.get_attr(func_attr_name).name
    func_graph = util.get_func_graph(while_op, input_shapes, func_name)
  func_graph._while = while_op
  return func_graph


def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
                      maximum_iterations, stateful_parallelism):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    cond_graph: FuncGraph for the forward cond function.
    body_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.
    maximum_iterations: Tensor. The maximum number of iterations.
    stateful_parallelism: Bool, see tf.while_loop.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
  assert len(ys) == len(grads)

  total_iters = while_op.outputs[0]
  counter = constant_op.constant(
      0, dtype=total_iters.dtype, name="grad_counter")

  # Build frozen sets so that we do not have linear time lookups in
  # `_is_loop_invariant`. Note: `body_graph.inputs` and `body_graph.outputs`
  # may get updated during gradient computation because we add accumulators to
  # the forward op. However, those are not loop invariants so wouldn't affect
  # the output of `_is_loop_invariant`. Also we would never attempt to capture
  # those accumulators so `_is_loop_invariant` should never receive those new
  # tensors as args.
  body_graph_inputs = object_identity.ObjectIdentitySet(body_graph.inputs)
  body_graph_outputs = object_identity.ObjectIdentitySet(body_graph.outputs)

  args = [counter, maximum_iterations, total_iters] + list(grads)
  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *args: _grad_fn(ys, xs, args, body_graph),
      args, {},
      func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
                                         maximum_iterations, while_op,
                                         body_graph_inputs, body_graph_outputs),
      acd_record_initial_resource_uses=stateful_parallelism)

  # Update the list of outputs with tensors corresponding to the captured
  # tensors. We capture 3 types of tensors when building the grad fn:
  # 1. Accumulators for forward graph intermediates which are not loop
  #    invariants. The outputs corresponding to these are populated in
  #    `internal_capture_to_output` by `_WhileBodyGradFuncGraph`.
  # 2. Resources, which are output as is.
  # 3. Forward graph loop invariants, which are output as is.
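  # For example (illustrative): for an accumulated forward intermediate, the
  # internal capture is the accumulator TensorList placeholder and the output
  # recorded for it in `internal_capture_to_output` is the list left over after
  # popping one value from it.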
  for external_capture, internal_capture in grad_func_graph.captures:
    if (ops.tensor_id(internal_capture)
        in grad_func_graph.internal_capture_to_output):
      new_output = grad_func_graph.internal_capture_to_output[ops.tensor_id(
          internal_capture)]
    else:
      raise ValueError(
          f"Tensor {str(internal_capture)} which captures "
          f"{str(external_capture)} is in list of "
          f"internal_captures but not in internal_capture_to_output.")
    grad_func_graph.outputs.append(new_output)
    grad_func_graph.structured_outputs.append(new_output)

  return grad_func_graph, args


def _grad_fn(ys, xs, args, func_graph):
  """Computes the gradient of `func_graph` in the current graph.

  This function builds the gradient graph of the corresponding forward-pass
  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    args: The input arguments.
      args[0] - Loop counter
      args[1] - maximum_iterations.
      args[2] - Total number of iterations.
      args[3:] - Incoming gradients for `ys`.
    func_graph: function.FuncGraph. The corresponding forward-pass function.

  Returns:
    The output gradient Tensors.
  """
  grad_ys = args[3:]

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # after the forward While op has been rewritten in _resolve_grad_captures.
  # TODO(srbs): Mark GradientsHelper as public?
  grad_outs = gradients_util._GradientsHelper(
      ys, xs, grad_ys=grad_ys, src_graph=func_graph,
      unconnected_gradients="zero")

  # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there
  # is a tf.StopGradient in the loop body.
  assert all(g is not None for g in grad_outs)
  counter = args[0]
  maximum_iterations = args[1]
  total_iters = args[2]
  return [counter + 1, maximum_iterations, total_iters] + grad_outs


def _resolve_grad_captures(body_graph, body_grad_graph, while_op):
  """Returns the tensors to pass as captured inputs to `body_grad_graph`.

  `body_grad_graph` may have external references to:
  1. Its outer graph containing the input gradients. These are left as-is.
  2. Accumulators captured from the forward-pass graph. These should have been
     added as `while_op` outputs after the gradient graph was built. We replace
     these with the corresponding output of `while_op`, i.e. a tensor in
     `body_graph.outer_graph`. In the case of nested control flow or functions,
     the gradient logic handling `body_grad_graph.outer_graph` will make sure
     the tensor from `body_graph.outer_graph` is also correctly captured.

  Args:
    body_graph: FuncGraph. The forward-pass body function.
    body_grad_graph: FuncGraph. The body gradients function.
    while_op: The forward-pass While Operation calling `body_graph`.

  Returns:
    A list of input tensors to be passed as the captured inputs to
    `body_grad_graph`.
  """
  new_capture_inputs = []
  for t in body_grad_graph.external_captures:
    # Resolve tensors captured from the forward graph to the outputs of the
    # forward while_op.
    if t.graph == body_graph:
      # Captured accumulator or loop invariant.
      for i, output in enumerate(t.graph.outputs):
        if output is t:
          t = while_op.outputs[i]
          break

      # Note: We rely on the capturing logic of the gradient While op graph to
      # correctly capture the tensors in `body_graph.outer_graph`. Both cond_v2
      # and while_v2 handle this while building their gradient functions.
      assert t.graph == body_graph.outer_graph

    new_capture_inputs.append(t)
  return new_capture_inputs


def _get_structured_grad_output(outputs, grads, body_grad_graph):
  """Returns the values that should be returned from the while grad function.

  Args:
    outputs: the raw Tensor outputs of the grad While op.
    grads: the input gradients to the gradient function.
    body_grad_graph: _WhileBodyGradFuncGraph.

  Returns:
    A list of gradient values. May include Nones.
  """
  result = []
  # outputs[0] is the loop counter.
  # outputs[1] is maximum_iterations.
  # outputs[2] is the total number of loop iterations.
  outputs_idx = 3
  structured_outputs_idx = 3
  for g in grads:
    # Set None as the output gradient for tensors with None input gradient.
    if g is None:
      result.append(None)
      continue
    output = body_grad_graph.structured_outputs[structured_outputs_idx]
    structured_outputs_idx += 1
    if isinstance(output, indexed_slices.IndexedSlices):
      # TODO(skyewm): is there a more robust way to determine the order of
      # flattened IndexedSlices components?
      result.append(indexed_slices.IndexedSlices(
          values=outputs[outputs_idx],
          indices=outputs[outputs_idx + 1],
          dense_shape=outputs[outputs_idx + 2]))
      outputs_idx += 3
    else:
      assert isinstance(output, ops.Tensor)
      result.append(outputs[outputs_idx])
      outputs_idx += 1

  return result


def _get_accumulator(tensor):
  r"""Returns the TensorList, if any, containing accumulated values of `tensor`.

  We try to find a pattern of the form:

     input_tl   tensor
        \        /
    (TensorListPushBack)
            |
        output_tl

  which satisfies the following conditions:

  1. input_tl must be in tensor.graph.inputs.
  2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
  3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_tl).

  output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
  returned if such a pattern is found else None is returned.

  Args:
    tensor: The Tensor to be accumulated.

  Returns:
    A variant tensor in the same graph as `tensor` or None if no accumulator is
    found.
  """
  assert isinstance(tensor.graph, func_graph_module.FuncGraph)

  def get_func_graph_output(t):
    """Returns t or Identity(t) whichever exists in graph outputs else None."""
    for output in tensor.graph.outputs:
      if output is t:
        return t
    # tf.defun adds an Identity for each output, check whether that is the case.
    identity_op = t.consumers()[0]
    if (identity_op.type == "Identity" and
        any(identity_op.outputs[0] is t for t in tensor.graph.outputs)):
      return identity_op.outputs[0]
    return None

  for consumer in tensor.consumers():
    # Find the consumer that is a TensorListPushBack node whose TensorList
    # input is in the list of function inputs.
    if consumer.type != "TensorListPushBack":
      continue

    accum_input_idx = -1
    for accum_input_idx, inp in enumerate(tensor.graph.inputs):
      if inp is consumer.inputs[0]:
        break
    else:
      continue

    output = get_func_graph_output(consumer.outputs[0])
    if output is None:
      # The TensorList output of `consumer` is not in the list of function
      # outputs.
      continue

    for accum_output_idx, out in enumerate(tensor.graph.outputs):
      if out is output:
        if accum_input_idx == accum_output_idx:
          return output
        break

  return None


OptimizedReductionOpsCacheKey = collections.namedtuple(
    "OptimizedReductionOpsCacheKey", [
        "op_type",
        "inputs",
        "dtypes",
        "input_types",
        "name",
        "attrs",
        "op_def",
        "compute_device",
    ])


class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
  """FuncGraph for the gradient function of the body of a While op.

  Contains the logic for capturing the tensors from the body of the forward
  While op which is as follows:
  1. If the tensor is of resource type (these are not accumulated):
     a. Ensure that the tensor is a loop invariant, i.e., it exists in both loop
        inputs and outputs at the same index.
     b. Look up the corresponding resource tensor in the forward outer graph and
        try to capture that.
  2. If the tensor is not of resource type:
     a. Create an accumulator for that tensor and output it from the forward
        pass. Note this also requires adding it as an input to the forward pass.
     b. Capture the accumulator from the forward pass in this FuncGraph. This
        will later be resolved to the correct output of the forward While op.
     c. Pop a value from the captured placeholder and use it as the captured
        value for the forward pass tensor.

  This only allows capturing tensors in the forward graph. A ValueError is
  raised if an attempt is made to capture a tensor not in the forward graph.
  To manually capture a tensor that is not in the forward graph, call `capture`
  with `allowlisted=True`.

  Note: The `captures` dict does not contain the forward tensor since it is not
  directly captured. It contains the accumulator corresponding to this forward
  tensor.

  Attributes:
    while_op_needs_rewrite: True if any non-resource intermediates were
      captured, meaning the forward While op needs to be rewritten to output
      the corresponding accumulators.
    extra_inputs: list of EmptyTensorList tensors to be used as initial input
      to the new accumulators in the forward graph. It may also contain
      external captures of the custom gradient function.
    internal_capture_to_output: dict from a tensor_id(captured placeholder) to
      the corresponding tensor that needs to be added to the list of outputs.
      For instance, when capturing an accumulator TensorList this contains the
      TensorList obtained after popping a tensor from the list. Other entries
      in this dict are expected, though not enforced, to be identities.
      This dict is needed because these output tensors need to be added to
      FuncGraph.outputs "after" the tensors returned from the gradient
      function.
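
  Example (an illustrative sketch): if the forward body computes an
  intermediate `t` that the gradient needs, an EmptyTensorList is fed into the
  forward While, `t` is pushed onto it on every forward iteration, and this
  gradient graph captures the final list and pops one value per backward
  iteration.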
  """

  def __init__(self, name, forward_cond_graph, forward_body_graph,
               maximum_iterations, forward_while_op, body_graph_inputs,
               body_graph_outputs):
    super(_WhileBodyGradFuncGraph, self).__init__(name)
    self.extra_inputs = []
    self.internal_capture_to_output = {}
    # FuncGraph for the body of the forward While op.
    self._forward_graph = forward_body_graph
    # FuncGraph for the cond of the forward While op.
    self._forward_cond_graph = forward_cond_graph
    self._maximum_iterations = maximum_iterations
    self._forward_while_op = forward_while_op
    # Dict from forward intermediate tensor to its indirectly captured tensor
    # in this graph. Indirect capturing happens in two ways:
    # 1. For non-resource tensors we capture their accumulators from the
    #    forward outer graph and pop values from that accumulator inside this
    #    graph using TensorListPopBack.
    # 2. For resource tensors we directly capture their corresponding tensor
    #    in the forward outer graph.
    self._indirect_captures = {}

  @property
  def while_op_needs_rewrite(self):
    return self.extra_inputs

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # For a reduction op, if op is in the gradient body graph and its input is
    # from the forward graph, moving op to the forward graph means we would
    # store the tensor after the reduction as opposed to the tensor before
    # reduction, and therefore could significantly reduce memory consumption.
    # For now, we do this only for a few ops.
    #
    # We don't do this if any input tensor has already been accumulated. This
    # can happen if we output all intermediates in the forward pass.
    #
    # If in XLA context, do not move constant ops to forward pass as pushing to
    # and popping from a TensorList removes the constant property of an op and
    # breaks XLA compilation, which requires certain inputs to be compile-time
    # constant for certain ops.
    #
    # This optimization is currently also disabled when under a persistent
    # tape, since it leads to an unbounded number of side outputs. With caching
    # it may be possible to re-enable it.
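    #
    # For example (illustrative): if the gradient only needs Shape(x) of some
    # large forward tensor x, moving the Shape op into the forward graph means
    # the loop accumulates the small shape vector rather than x itself.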
    optimized_reduction_ops = {
        "Shape", "Size", "Rank", "TensorListElementShape", "TensorListLength"
    }
    if (op_type in optimized_reduction_ops and
        not util.output_all_intermediates() and
        all(input.graph is self._forward_graph for input in inputs) and
        all(_get_accumulator(input) is None for input in inputs) and
        not util_v1.GraphOrParentsInXlaContext(self._forward_graph) and
        not util.graph_wrapped_for_higher_order_tape_gradients(
            self._forward_graph)):
      return self._move_op_to_forward_graph(
          op_type,
          inputs,
          dtypes=dtypes,
          input_types=input_types,
          name=name,
          attrs=attrs,
          op_def=op_def,
          compute_device=compute_device)

    return super(_WhileBodyGradFuncGraph, self)._create_op_internal(
        op_type,
        inputs,
        dtypes=dtypes,
        input_types=input_types,
        name=name,
        attrs=attrs,
        op_def=op_def,
        compute_device=compute_device)

  def _move_op_to_forward_graph(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # We have a cache of reduction ops that have already been moved to the
    # forward graph, and we will check it first to avoid moving an op twice.
    if not hasattr(self._forward_graph, "_optimized_reduction_ops_cache"):
      self._forward_graph._optimized_reduction_ops_cache = {}
    cache_key = self._get_optimized_reduction_ops_cache_key(
        op_type, inputs, dtypes, input_types, name, attrs, op_def,
        compute_device)
    cached_op = self._forward_graph._optimized_reduction_ops_cache.get(
        cache_key)
    if cached_op is not None:
      # This op has already been moved to the forward graph and we have it in
      # the cache.
      return cached_op

    with self._forward_graph.as_default():
      # `name` was built using name_scope stack of gradient graph and may not
      # be unique in the forward graph. `Graph.create_op` does not uniquify
      # names which are name scopes i.e. end in `/`. To ensure that the op
      # created gets a unique name in the forward graph we get rid of the
      # trailing slash.
      name = ops.name_from_scope_name(name)
      result = self._forward_graph._create_op_internal(
          op_type,
          inputs,
          dtypes=dtypes,
          input_types=input_types,
          name=name,
          attrs=attrs,
          op_def=op_def,
          compute_device=compute_device)

      # Store the op we just moved to the forward graph so that it does
      # not need to be added there again.
      self._forward_graph._optimized_reduction_ops_cache[cache_key] = result
      return result

  def _get_optimized_reduction_ops_cache_key(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # We need all elements of CacheKey to be hashable.
    inputs = tuple(map(lambda t: t.ref(), inputs))

    if dtypes is not None:
      dtypes = tuple(dtypes)

    if input_types is not None:
      input_types = tuple(input_types)

    if attrs is not None:
      hashable_attrs = []
      for attr_name, attr_value in sorted(attrs.items()):
        hashable_attrs.append((attr_name, attr_value.SerializeToString()))
      attrs = tuple(hashable_attrs)

    if op_def is not None:
      op_def = op_def.SerializeToString()

    return OptimizedReductionOpsCacheKey(op_type, inputs, dtypes, input_types,
                                         name, attrs, op_def, compute_device)

  def _capture_helper(self, tensor, name):
    """Implements the capturing described in the class docstring."""
    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
    if captured_tensor is not None:
      return captured_tensor

    if tensor.graph is not self._forward_graph:
      already_captured = self.captured(tensor)
      captured_tensor = super(_WhileBodyGradFuncGraph, self)._capture_helper(
          tensor, name)
      if not already_captured:
        # Adds the captured tensor to the list of outputs so that the input
        # and output signatures match.
        self.internal_capture_to_output[ops.tensor_id(
            captured_tensor)] = captured_tensor
        self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
      return captured_tensor

    while tensor.op.type == "Identity":
      # We do not accumulate the output of identity nodes so we try to capture
      # the input of the Identity node instead.
      tensor = tensor.op.inputs[0]

    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
    if captured_tensor is not None:
      return captured_tensor

    # No need to accumulate loop invariants. Capture them directly.
    # The captured tensor gets resolved to the corresponding while output in
    # `_resolve_grad_captures`.
    if _is_loop_invariant(tensor, self._forward_graph.inputs,
                          self._forward_graph.outputs):
      captured_tensor = super(_WhileBodyGradFuncGraph,
                              self)._capture_helper(tensor, name)
      # Add to `internal_capture_to_output` so that this gets added to the list
      # of outputs.
      self.internal_capture_to_output[ops.tensor_id(
          captured_tensor)] = captured_tensor
      self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
      return captured_tensor

    # Do not accumulate Const nodes. Instead copy them directly in the backward
    # graph.
    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
    # graph compile time consts in general.
    # TODO(srbs): Consider making this a loop input.
    if constant_op.is_constant(tensor):
      real_value = constant_op.constant(
          tensor_util.constant_value(tensor), dtype=tensor.dtype)
      self._indirect_captures[ops.tensor_id(tensor)] = real_value
      return real_value

    # Resource tensors are not accumulated and handled specially.
    if tensor.dtype == dtypes.resource:
      return self._resource_capture_helper(tensor)

    # Create or find an existing accumulator output for `tensor` in the forward
    # graph, and fetch from this accumulator in the gradient graph to get the
    # raw intermediate value.
    accumulator = _get_accumulator(tensor)
    if accumulator is None:
      # Create the initial empty tensor list.
      #
      # Note: We clear the control dependencies to avoid a cycle in case a
      # control tensor has an input path to an output of the forward While.
      #
      # E.g.:
      # x = tf.while_loop(...)
      # y = f(x)
      # with tf.control_dependencies([y]):
      #   tf.gradients(y, x)
      #
      # Since the EmptyTensorList is fed back into the forward While, not
      # removing the control edge would cause a cycle.
      with self._forward_graph.outer_graph.as_default():
        with util.clear_control_inputs():
          tensor_list = list_ops.empty_tensor_list(
              element_dtype=tensor.dtype,
              element_shape=tensor.shape,
              max_num_elements=self._maximum_iterations,
              name=_build_accumulator_name(tensor))
      self.extra_inputs.append(tensor_list)

      # Push the intermediate tensor to the tensor list. This captures
      # `tensor_list`.
      with self._forward_graph.as_default():
        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
      # Add the modified tensor list to the list of outputs. This output will be
      # all the accumulated values.
      self._forward_graph.outputs.append(accumulator)

      # Capture in the cond graph as well so the forward cond and body inputs
      # match.
      with self._forward_cond_graph.as_default():
        self._forward_cond_graph.capture(tensor_list)

    # Capture the accumulator tensor list in the gradient graph directly from
    # the forward graph -- we'll later modify this to capture the final list
    # output by the forward While op instead.
    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
        accumulator, name)

    # Pop the intermediate value from the tensor list in the gradient graph.
    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
        captured_accumulator, element_dtype=tensor.dtype)

    self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
    self.internal_capture_to_output[ops.tensor_id(
        captured_accumulator)] = new_tensor_list
    return captured_tensor

  def _resource_capture_helper(self, tensor):
    """Returns the captured resource tensor.

    Resource-type tensors are not accumulated. If a resource tensor exists in
    the loop body it must either be a loop input or an output of a nested While
    op inside the loop body which had captured the external resource.

    Args:
      tensor: the external resource Tensor to be captured.

    Returns:
      Tensor in this graph.
    """
    assert tensor.dtype == dtypes.resource

    forward_graph_input_names = [t.name for t in self._forward_graph.inputs]
    forward_graph_name_to_opdef = {
        op.name: op.node_def for op in self._forward_graph.get_operations()}
    index = util.resource_input_index(
        tensor.name, forward_graph_input_names,
        forward_graph_name_to_opdef,
        self._forward_graph._functions)

    input_placeholder = self._forward_graph.inputs[index]
    tensor_in_outer_graph = self._forward_graph._while.inputs[index]

    assert input_placeholder.dtype == dtypes.resource
    assert tensor_in_outer_graph.dtype == dtypes.resource
    # This must be a loop invariant. However, infrastructure
    # (e.g. tf.vectorized_map) may insert identity nodes, function calls,
    # conds, etc. which take and return the resource tensor unmodified; this
    # means that the Python objects may differ.
    if index != util.resource_input_index(
        self._forward_graph.outputs[index].name, forward_graph_input_names,
        forward_graph_name_to_opdef,
        self._forward_graph._functions):
      raise AssertionError(
          f"Resource tensors must be loop invariants {tensor_in_outer_graph}")

    self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
        tensor_in_outer_graph)
    return self._indirect_captures[ops.tensor_id(tensor)]


def _check_shapes_compat(flat_output_tensors, flat_shape_invariants,
                         flat_input_tensors):
  for (t, shape, input_t) in zip(flat_output_tensors, flat_shape_invariants,
                                 flat_input_tensors):
    if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape):
      raise ValueError(
          f"Input tensor `{input_t.name}` enters the loop with shape {shape}, "
          f"but has shape {t.shape} after one iteration. To allow the shape to "
          "vary across iterations, use the `shape_invariants` argument of "
          "tf.while_loop to specify a less-specific shape.")


def _check_num_inputs_outputs(cond_graph, body_graph, num_flattened_loop_vars):
  """Checks the number of inputs/outputs of `cond_graph` and `body_graph`."""
  assert len(cond_graph.inputs) == num_flattened_loop_vars, (
      "cond_graph takes %d inputs; Expected: %d" % (len(cond_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(cond_graph.outputs) == 1, (
      "cond_graph has %d outputs; Expected: 1" % len(cond_graph.outputs))
  assert len(body_graph.inputs) == num_flattened_loop_vars, (
      "body_graph takes %d inputs; Expected: %d" % (len(body_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(body_graph.outputs) == num_flattened_loop_vars, (
      "body_graph has %d outputs; Expected: %d" % (len(body_graph.outputs),
                                                   num_flattened_loop_vars))


def _check_inputs_outputs_types_match(body_graph, flattened_loop_vars):
  for inp, out, loop_var in zip(body_graph.inputs, body_graph.outputs,
                                flattened_loop_vars):
    if inp.dtype != out.dtype:
      raise TypeError(
          f"Loop var {loop_var.name} enters the loop with type {inp.dtype} "
          f"but has type {out.dtype} after 1 iteration. {loop_var.name} type "
          "should remain constant.")


def _build_cond_placeholders_name_prefix(cond_graph):
  return cond_graph.unique_name(cond_graph.name + "___redundant_placeholder")


def _duplicate_body_captures_in_cond(cond_graph, body_graph_captures):
  """Creates placeholders for body captures in cond_graph.

  This is needed to match signatures of cond and body graphs.

  Args:
    cond_graph: cond branch graph
    body_graph_captures: Tensors which were captured when building the
      `body_graph`.
  """
  types = [t.dtype.as_datatype_enum for t in body_graph_captures]
  # TODO(srbs): Providing a unique prefix does not ensure that there is no
  # conflict between the placeholder names and existing nodes in the graph.
  # However passing a list of strings may not be performant.
  # Ideally we should move `Graph.unique_name` to C++ or make
  # `Graph._names_in_use` a trie so that we can find a unique prefix.
  # TODO(b/143286622): This should not be required once captures are separated
  # from regular loop vars.
  with cond_graph._c_graph.get() as c_graph:
    placeholders = c_api.TF_CreatePlaceholders(
        c_graph, types,
        compat.as_str(_build_cond_placeholders_name_prefix(cond_graph)))
  placeholder_ops = [
      _OperationWithOutputs(ph.oper, cond_graph)
      for ph in placeholders
  ]

  tensors = []
  for op, ph, dtype in zip(placeholder_ops, placeholders, types):
    tensor = ops.Tensor._create_with_tf_output(op, 0, dtype, ph)
    op._outputs = [tensor]
    tensors.append(tensor)

  # Update `cond_graph._captures` and `cond_graph.inputs` to contain the
  # newly created placeholders.
  tuples = zip(body_graph_captures, tensors)
  keys = [id(t) for t in body_graph_captures]
  cond_graph._captures.update(zip(keys, tuples))
  cond_graph.inputs.extend(tensors)


def _copy_handle_data(src_tensors, tgt_tensors):
  for src_t, tgt_t in zip(src_tensors, tgt_tensors):
    handle_data_util.copy_handle_data(src_t, tgt_t)


def _graph_name(graph):
  if isinstance(graph, func_graph_module.FuncGraph):
    return graph.name
  return "Base"


def _pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, loop_vars):
  """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""

  def flow_to_tensor_array(flow, ta):  # pylint: disable=missing-docstring
    return (tensor_array_ops.build_ta_with_new_flow(ta, flow) if isinstance(  # pylint: disable=g-long-ternary
        ta, tensor_array_ops.TensorArray) else flow)

  flattened_loop_vars = [
      flow_to_tensor_array(*z)
      for z in zip(nest.flatten(loop_vars, expand_composites=True),
                   flat_orig_loop_vars)
  ]
  return nest.pack_sequence_as(loop_vars_signature, flattened_loop_vars,
                               expand_composites=True)


def _tensor_array_to_flow(loop_vars):

  def f(maybe_ta):
    if isinstance(maybe_ta, tensor_array_ops.TensorArray):
      return maybe_ta.flow
    return maybe_ta

  return nest.map_structure(f, loop_vars, expand_composites=True)


def _build_maximum_iterations_loop_var(maximum_iterations):
  if maximum_iterations is None:
    # Default value for max_num_elements to EmptyTensorList meaning that the
    # list size is unbounded.
    maximum_iterations = -1
  # EmptyTensorList expects `max_num_elements` to be of type int32.
  return ops.convert_to_tensor(
      maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")


def _build_accumulator_name(tensor):
  # Tensor name may be of the form "pow/y:0". Name scope does not allow ":".
  return "{}/accumulator".format(tensor.name).replace(":", "_")


def _is_loop_invariant(tensor, inputs, outputs):
  return (any(tensor is t for t in inputs) and
          any(tensor is t for t in outputs))


class _OperationWithOutputs(ops.Operation):
  """Operation with pre-built `TF_Output`s.

  The C API for creating the extra placeholders for the cond graph returns
  SWIG wrapped TF_Output* pointers which we can use directly for
  `Operation.outputs`. The default constructor for `Operation` does not provide
  a way of specifying pre-built output tensors and always creates them. This is
  a performance overhead. It is not clear if adding that feature to the
  `Operation` API would be generally useful so for now we just have our own
  lightweight `Operation` implementation. Note that this also does not extract
  a stacktrace, since we don't expect this operation to be used.

  TODO(b/143286622): This should not be required once captures are separated
  from regular loop vars.
  """

  def __init__(self, c_op, g):
    self._c_op = c_op
    self._graph = g
    self._outputs = None  # Initialized by _duplicate_body_captures_in_cond().
    self._id_value = g._add_op(self, self.name)
    self._is_stateful = False


def _set_read_only_resource_inputs_attr(op, branch_graphs):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: While Operation.
    branch_graphs: List of branch FuncGraphs.
  """
  read_only_indices = set(range(len(op.inputs)))
  for branch_graph in branch_graphs:
    if not read_only_indices:
      break
    branch_read_only_indices = acd.get_read_only_resource_input_indices_graph(
        branch_graph)
    read_only_indices = read_only_indices.intersection(branch_read_only_indices)

  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        sorted(read_only_indices))

# pylint: enable=protected-access