1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Functional operations.""" 16 17from tensorflow.core.framework import attr_value_pb2 18from tensorflow.python.eager import context 19from tensorflow.python.framework import auto_control_deps_utils as acd 20from tensorflow.python.framework import constant_op 21from tensorflow.python.framework import dtypes 22from tensorflow.python.framework import function 23from tensorflow.python.framework import ops 24from tensorflow.python.framework import tensor_shape 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import control_flow_ops 27from tensorflow.python.ops import gen_functional_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.ops import tensor_array_ops 30from tensorflow.python.ops import variable_scope as vs 31# pylint: disable=unused-import 32from tensorflow.python.ops.gen_functional_ops import remote_call 33# pylint: enable=unused-import 34from tensorflow.python.ops.gen_functional_ops import symbolic_gradient 35from tensorflow.python.util import compat 36from tensorflow.python.util import deprecation 37from tensorflow.python.util import dispatch 38from tensorflow.python.util import function_utils 39from tensorflow.python.util import nest 40from tensorflow.python.util.tf_export import tf_export 41 42 43# TODO(yuanbyu, 
def _raise_if_not_callable(fn):
  """Raises `TypeError` if `fn` is not callable.

  The message uses `repr(fn)` rather than `fn.__name__`: a non-callable object
  (a tensor, a number, `None`, ...) usually has no `__name__`, so formatting
  the message with it would raise `AttributeError` instead of the documented
  `TypeError`.

  Args:
    fn: The object that is expected to be callable.

  Raises:
    TypeError: if `fn` is not callable.
  """
  if not callable(fn):
    raise TypeError(
        f"{fn!r} is not callable. Please provide a callable function.")


def _leading_dim_size(tensor):
  """Returns the size of dimension 0 of `tensor`.

  Args:
    tensor: A `Tensor` of rank at least 1.

  Returns:
    The statically known size as a Python int when available (including 0),
    otherwise a scalar tensor holding the size computed at run time.
  """
  size = tensor_shape.dimension_value(tensor.shape[0])
  if size is None:
    size = array_ops.shape(tensor)[0]
  return size


def _unstack_to_tensor_arrays(structure, elems_flat, n):
  """Unstacks each tensor of `elems_flat` along dimension 0 into a TensorArray.

  Args:
    structure: The (possibly nested) structure that `elems_flat` was flattened
      from.
    elems_flat: A flat list of tensors, each expected to have leading
      dimension `n`.
    n: The size of dimension 0 shared by the tensors of `elems_flat`.

  Returns:
    One `TensorArray` per tensor of `elems_flat`, packed into `structure`.
  """
  tensor_arrays = [
      tensor_array_ops.TensorArray(
          dtype=elem.dtype, size=n, dynamic_size=False,
          infer_shape=True).unstack(elem) for elem in elems_flat
  ]
  return nest.pack_sequence_as(structure, tensor_arrays)


def _set_default_caching_device():
  """In graph mode, makes the current variable scope cache variable reads.

  With a caching device set, any `get_variable` calls in a loop body cache the
  first call locally and do not issue repeated network I/O requests for each
  iteration. Nothing is changed when executing eagerly, or when the current
  variable scope already has a caching device.

  Returns:
    The variable scope whose caching device was set (pass it to
    `_reset_caching_device` when done, even on error), or None if nothing was
    changed.
  """
  # TODO(akshayka): Remove the in_graph_mode check once caching devices are
  # supported in Eager
  if context.executing_eagerly():
    return None
  varscope = vs.get_variable_scope()
  if varscope.caching_device is not None:
    return None
  # TODO(ebrevdo): Change to using colocate_with here and in other
  # methods.
  varscope.set_caching_device(lambda op: op.device)
  return varscope


def _reset_caching_device(varscope):
  """Clears the caching device installed by `_set_default_caching_device`.

  Args:
    varscope: The value returned by `_set_default_caching_device`; None makes
      this a no-op.
  """
  if varscope is not None:
    varscope.set_caching_device(None)


# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
@tf_export(v1=["foldl"])
@dispatch.add_dispatch_support
def foldl(fn,
          elems,
          initializer=None,
          parallel_iterations=10,
          back_prop=True,
          swap_memory=False,
          name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension, and the second argument of
  `fn` has the same structure as `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is
  `fn = lambda a, t:`, where `t` has the structure `(t1, [t2, t3, [t4, t5]])`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  _raise_if_not_callable(fn)

  with ops.name_scope(name, "foldl", [elems]):
    varscope_to_reset = _set_default_caching_device()
    # The caching device is reset in `finally` so that an exception raised by
    # `fn` does not leave it installed on the (global) variable scope.
    try:
      # Convert elems to tensor array. n may be known statically.
      elems_flat = [
          ops.convert_to_tensor(elem, name="elem")
          for elem in nest.flatten(elems)
      ]
      n = _leading_dim_size(elems_flat[0])
      elems_ta = _unstack_to_tensor_arrays(elems, elems_flat, n)

      if initializer is None:
        a = nest.map_structure(lambda elem: elem.read(0), elems_ta)
        i = constant_op.constant(1)
      else:
        a = initializer
        i = constant_op.constant(0)

      def compute(i, a):
        elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta)
        a = fn(a, elem_i)
        return [i + 1, a]

      _, r_a = control_flow_ops.while_loop(
          lambda i, a: i < n,
          compute, [i, a],
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory,
          maximum_iterations=n)
    finally:
      _reset_caching_device(varscope_to_reset)

    return r_a


@tf_export("foldl", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldl(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldl(fn, elems))""",
    warn_once=True,
    back_prop=False)
def foldl_v2(fn,
             elems,
             initializer=None,
             parallel_iterations=10,
             back_prop=True,
             swap_memory=False,
             name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension, and the second argument of
  `fn` has the same structure as `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is
  `fn = lambda a, t:`, where `t` has the structure `(t1, [t2, t3, [t4, t5]])`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = tf.foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  return foldl(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name)


@tf_export(v1=["foldr"])
@dispatch.add_dispatch_support
def foldr(fn,
          elems,
          initializer=None,
          parallel_iterations=10,
          back_prop=True,
          swap_memory=False,
          name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn, and the second is the value at the current position of
  `elems`. If `initializer` is None, `elems` must contain at least one element,
  and its last element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[-1]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension, and the second argument of
  `fn` has the same structure as `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is
  `fn = lambda a, t:`, where `t` has the structure `(t1, [t2, t3, [t4, t5]])`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  _raise_if_not_callable(fn)

  with ops.name_scope(name, "foldr", [elems]):
    varscope_to_reset = _set_default_caching_device()
    # The caching device is reset in `finally` so that an exception raised by
    # `fn` does not leave it installed on the (global) variable scope.
    try:
      # Convert elems to tensor array. n may be known statically.
      elems_flat = [
          ops.convert_to_tensor(elem, name="elem")
          for elem in nest.flatten(elems)
      ]
      n = _leading_dim_size(elems_flat[0])
      elems_ta = _unstack_to_tensor_arrays(elems, elems_flat, n)

      if initializer is None:
        # The last element seeds the accumulator; the loop then visits
        # indices n - 2 down to 0.
        i = n - 1
        a = nest.map_structure(lambda elem: elem.read(i), elems_ta)
      else:
        i = n
        a = initializer

      def compute(i, a):
        # Decrement first: `i` is one past the index to read next.
        i -= 1
        elem = nest.map_structure(lambda elem: elem.read(i), elems_ta)
        a_out = fn(a, elem)
        return [i, a_out]

      _, r_a = control_flow_ops.while_loop(
          lambda i, a: i > 0,
          compute, [i, a],
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory,
          maximum_iterations=n)
    finally:
      _reset_caching_device(varscope_to_reset)

    return r_a


@tf_export("foldr", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))""",
    warn_once=True,
    back_prop=False)
def foldr_v2(fn,
             elems,
             initializer=None,
             parallel_iterations=10,
             back_prop=True,
             swap_memory=False,
             name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn, and the second is the value at the current position of
  `elems`. If `initializer` is None, `elems` must contain at least one element,
  and its last element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[-1]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension, and the second argument of
  `fn` has the same structure as `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is
  `fn = lambda a, t:`, where `t` has the structure `(t1, [t2, t3, [t4, t5]])`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = tf.foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  return foldr(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name)


@tf_export(v1=["scan"])
@dispatch.add_dispatch_support
def scan(fn,
         elems,
         initializer=None,
         parallel_iterations=10,
         back_prop=True,
         swap_memory=False,
         infer_shape=True,
         reverse=False,
         name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  See also `tf.map_fn`.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If reverse=True, it's fn(initializer, values[-1]).shape.

  This method also allows multi-arity `elems` and accumulator. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.

  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
  one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed. It accepts two arguments. The first will
      have the same structure as `initializer` if one is provided, otherwise it
      will have the same structure as `elems`. The second will have the same
      (possibly nested) structure as `elems`. Its output must have the same
      structure as `initializer` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first to
      last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  _raise_if_not_callable(fn)

  input_is_sequence = nest.is_nested(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  if initializer is None:
    output_is_sequence = input_is_sequence
    output_flatten = input_flatten
    output_pack = input_pack
  else:
    output_is_sequence = nest.is_nested(initializer)
    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]

    def output_pack(x):
      return (nest.pack_sequence_as(initializer, x)
              if output_is_sequence else x[0])

  elems_flat = input_flatten(elems)

  with ops.name_scope(name, "scan", elems_flat):
    varscope_to_reset = _set_default_caching_device()
    # The caching device is reset in `finally` so that an exception raised by
    # `fn` does not leave it installed on the (global) variable scope.
    try:
      # Convert elems to tensors.
      elems_flat = [
          ops.convert_to_tensor(elem, name="elem") for elem in elems_flat
      ]

      # n may be known statically.
      n = _leading_dim_size(elems_flat[0])

      # TensorArrays are always flat
      elems_ta = [
          tensor_array_ops.TensorArray(
              dtype=elem.dtype,
              size=n,
              dynamic_size=False,
              element_shape=elem.shape[1:],
              infer_shape=True) for elem in elems_flat
      ]
      # Unpack elements
      elems_ta = [
          elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)
      ]

      if initializer is None:
        a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
        i = 1
      else:
        initializer_flat = output_flatten(initializer)
        a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
        i = 0

      # Create a tensor array to store the intermediate values.
      accs_ta = [
          tensor_array_ops.TensorArray(
              dtype=init.dtype,
              size=n,
              element_shape=init.shape if infer_shape else None,
              dynamic_size=False,
              infer_shape=infer_shape) for init in a_flat
      ]

      # Without an initializer, the seed element is itself the first output.
      if initializer is None:
        accs_ta = [
            acc_ta.write(n - 1 if reverse else 0, a)
            for (acc_ta, a) in zip(accs_ta, a_flat)
        ]

      def compute(i, a_flat, tas):
        """The loop body of scan.

        Args:
          i: the loop counter.
          a_flat: the accumulator value(s), flattened.
          tas: the output accumulator TensorArray(s), flattened.

        Returns:
          [i + 1, a_flat, tas]: the updated counter + new accumulator values +
            updated TensorArrays (the counter moves by -1 when `reverse`).

        Raises:
          TypeError: if initializer and fn() output structure do not match
          ValueError: if initializer and fn() output lengths do not match
        """
        packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
        packed_a = output_pack(a_flat)
        a_out = fn(packed_a, packed_elems)
        nest.assert_same_structure(
            elems if initializer is None else initializer, a_out)
        flat_a_out = output_flatten(a_out)
        tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
        if reverse:
          next_i = i - 1
        else:
          next_i = i + 1
        return (next_i, flat_a_out, tas)

      if reverse:
        initial_i = n - 1 - i
        condition = lambda i, _1, _2: i >= 0
      else:
        initial_i = i
        condition = lambda i, _1, _2: i < n
      _, _, r_a = control_flow_ops.while_loop(
          condition,
          compute, (initial_i, a_flat, accs_ta),
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory,
          maximum_iterations=n)

      results_flat = [r.stack() for r in r_a]

      # Restore the statically known leading dimension on every output, after
      # checking that all inputs agree on it.
      n_static = tensor_shape.Dimension(
          tensor_shape.dimension_value(
              elems_flat[0].get_shape().with_rank_at_least(1)[0]))
      for elem in elems_flat[1:]:
        n_static.assert_is_compatible_with(
            tensor_shape.Dimension(
                tensor_shape.dimension_value(
                    elem.get_shape().with_rank_at_least(1)[0])))
      for r in results_flat:
        r.set_shape(
            tensor_shape.TensorShape(n_static).concatenate(r.get_shape()[1:]))
    finally:
      _reset_caching_device(varscope_to_reset)

    return output_pack(results_flat)


@tf_export("scan", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.scan(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.scan(fn, elems))""",
    warn_once=True,
    back_prop=False)
def scan_v2(fn,
            elems,
            initializer=None,
            parallel_iterations=10,
            back_prop=True,
            swap_memory=False,
            infer_shape=True,
            reverse=False,
            name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If reverse=True, it's fn(initializer, values[-1]).shape.

  This method also allows multi-arity `elems` and accumulator. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.

  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
  one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed. It accepts two arguments. The first will
      have the same structure as `initializer` if one is provided, otherwise it
      will have the same structure as `elems`. The second will have the same
      (possibly nested) structure as `elems`. Its output must have the same
      structure as `initializer` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first to
      last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  return scan(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      infer_shape=infer_shape,
      reverse=reverse,
      name=name)


# pylint: disable=invalid-name
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = cond ? then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the scalar is a
      numerical value, non-zero means True and zero means False; if the scalar
      is a string, non-empty means True and empty means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs)
    or else_branch(inputs).
  """
  # pylint: disable=protected-access
  # Handle the Defun case until users have transitioned to tf.function. Note
  # that composites may need to be re-packed by the caller.
  if isinstance(then_branch, function._DefinedFunction):
    tlist = [_.type for _ in then_branch.definition.signature.output_arg]
    return gen_functional_ops._if(
        cond, inputs, tlist, then_branch, else_branch, name=name)

  # We assume that `then_branch` is a ConcreteFunction here.
  then_out = then_branch.structured_outputs
  else_out = else_branch.structured_outputs

  # Ensure then/else are the same type of composites to avoid an invalid call
  # to pack_sequence_as later on.
  nest.assert_same_structure(then_out, else_out, expand_composites=True)

  tlist = nest.flatten(then_branch.output_dtypes)
  ret = gen_functional_ops._if(
      cond, inputs, tlist, then_branch, else_branch, name=name)

  # Re-pack the outputs to restore any CompositeTensors
  return nest.pack_sequence_as(then_out, ret, expand_composites=True)


def Gradient(inputs, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  Args:
    inputs: A list of tensors of size N + M.
    f: The function we want to compute the gradient for. The function 'f' must
      be a numerical function which takes N inputs and produces M outputs. Its
      gradient function 'g' is a function taking N + M inputs and producing N
      outputs. I.e. if we have (y1, y2, ..., yM) = f(x1, x2, ..., xN), then g
      is (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN, dL/dy1, dL/dy2,
      ..., dL/dyM), where L is a scalar-value function of (x1, x2, ..., xN)
      (e.g., the loss function). dL/dxi is the partial derivative of L with
      respect to xi.
    name: A name for the operation (optional).

  Returns:
    A list of tensors of size N.
  """
  # TODO(zhifengc): Pretty-print the above spec in latex.
  # TODO(zhfiengc): Needs some math expert to say the comment above better.
  # The gradient outputs have the same types as f's inputs.
  tlist = [_.type for _ in f.definition.signature.input_arg]
  return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name)


def _GetInputDtypes(func):
  """Returns the input dtypes of func, excluding dtypes for captured inputs."""
  if isinstance(func, function._DefinedFunction):  # pylint: disable=protected-access
    return func.declared_input_types

  # We assume that `func` is a ConcreteFunction here, but we are not able to
  # verify since importing eager function library will cause cyclic dependence.
  #
  # ConcreteFunction.inputs includes captured inputs.
  num_non_captured_inputs = len(func.inputs) - len(func.captured_inputs)
  inputs_without_captured = func.inputs[:num_non_captured_inputs]
  return [t.dtype for t in inputs_without_captured]


def _LoopBodyCaptureWrapper(func):
  """Returns a wrapper for `func` that handles loop-carried captured inputs."""

  @function.Defun(*_GetInputDtypes(func), func_name="%s_Wrapper" % func.name)
  def Wrapper(*args):
    """A wrapper that handles loop-carried captured inputs."""
    result = func(*args)
    extra_args = tuple(function.get_extra_args())
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(result, ops.Operation):
      return extra_args
    # Unary functions return a single Tensor value.
914 elif not isinstance(result, (list, tuple)): 915 return (result,) + extra_args 916 # N-ary functions return a tuple of Tensors. 917 else: 918 return result + type(result)(extra_args) 919 920 return Wrapper 921 922 923# pylint: disable=invalid-name,protected-access 924def While(input_, cond, body, name=None, hostmem=None): 925 r"""output = input; While (Cond(output)) { output = Body(output) }. 926 927 Args: 928 input_: A list of `Tensor` objects. A list of input tensors whose types are 929 T. 930 cond: . A function takes 'input' and returns a tensor. If the tensor is a 931 scalar of non-boolean, the scalar is converted to a boolean 932 according to the following rule: if the scalar is a numerical value, 933 non-zero means True and zero means False; if the scalar is a string, 934 non-empty means True and empty means False. If the tensor is not a 935 scalar, non-emptiness means True and False otherwise. 936 body: . A function takes a list of tensors and returns another list tensors. 937 Both lists have the same types as specified by T. 938 name: A name for the operation (optional). 939 hostmem: A list of integer. If i is in the list, input[i] is a host memory 940 tensor. 941 942 Raises: 943 ValueError: if `cond` has implicitly captured inputs or if `cond` and `body` 944 have different signatures. 945 946 Returns: 947 A list of `Tensor` objects. Has the same type as `input`. 948 A list of output tensors whose types are T. 949 """ 950 if cond.captured_inputs: 951 raise ValueError( 952 "The 'cond' argument can not have implicitly captured inputs. Received " 953 f"captured_inputs: {cond.captured_inputs}") 954 955 cond_input_types = _GetInputDtypes(cond) 956 body_input_types = _GetInputDtypes(body) 957 958 if cond_input_types != body_input_types: 959 raise ValueError( 960 "The 'cond' and 'body' signatures do not match. 
Received: " 961 f"cond_input_types={cond_input_types}, body_input_types=" 962 f"{body_input_types}") 963 964 if body.captured_inputs: 965 cond_dtypes = list(body_input_types) + [ 966 t.dtype for t in body.captured_inputs 967 ] 968 969 @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name) 970 def CondWrapper(*args): 971 """A wrapper that handles loop-carried captured inputs.""" 972 return cond(*args[:len(body_input_types)]) 973 974 ret = gen_functional_ops._while( 975 input_ + body.captured_inputs, 976 CondWrapper, 977 _LoopBodyCaptureWrapper(body), 978 name=name) 979 # Slice off the loop-carried captured inputs. 980 ret = ret[:-len(body.captured_inputs)] 981 else: 982 ret = gen_functional_ops._while(input_, cond, body, name=name) 983 if hostmem: 984 input_attr = attr_value_pb2.AttrValue() 985 input_attr.list.i.extend(hostmem) 986 ret[0].op._set_attr("_input_hostmem", input_attr) # pylint: disable=protected-access 987 988 output_attr = attr_value_pb2.AttrValue() 989 output_attr.list.i.extend(hostmem) 990 ret[0].op._set_attr("_output_hostmem", output_attr) # pylint: disable=protected-access 991 return ret 992 993 994# b/36459430 995# 996# Ideally, we do not need this rewrite For loop into a While loop. 997# However, today, if a While runs on GPU and the condition returns a 998# boolean, the While kernel crashes. Even if we fix the crash, the 999# bool needs to be copied between GPU and CPU. So, a for loop is much 1000# preferred when running on GPU. 1001# 1002# On the other hand, For op has no directly XLA kernel. So, when we run 1003# a for loop, we need to rewrite it using a While op. 1004# 1005# It should be possible and probably better to write a XLA C++ kernel 1006# implementing the logic in _ForUsingWhile. 
def _ForUsingWhile(start,
                   limit,
                   delta,
                   inputs,
                   forbody,
                   name=None,
                   hostmem=None):
  """Helper to implement a For loop using a While.

  The While loop carries four int32 loop-control values `(i, n, start, delta)`
  in front of `inputs`; they are dropped from the returned list.

  Args:
    start: int32 start of the iteration range.
    limit: int32 (exclusive) limit of the iteration range.
    delta: int32 step of the iteration range; may be negative.
    inputs: A list of loop-carried input tensors.
    forbody: The loop body, called as `forbody(index, *inputs)`. It must expose
      `name` and `declared_input_types` (as a Defun-style `_DefinedFunction`
      does).
    name: A name for the operation (optional).
    hostmem: Optional list of indices into `inputs` that are host memory
      tensors.

  Returns:
    A list holding the final values of the loop-carried tensors.
  """
  # To support negative delta (e.g., range(100, 0, -3)), we iterate
  # over the range(n) and use iter * delta + start as the real
  # iteration index. (e.g., for i in range(34): iter = i * (-3) +
  # 100).
  d = math_ops.abs(delta)
  # n = ceil(|limit - start| / |delta|), computed as
  # trunc((|limit - start| + d - 1) / d) in float32 because XLA on TPUs
  # doesn't support integer division.
  # NOTE(review): float32 cannot represent every int32 exactly, so very large
  # trip counts may be rounded; confirm the expected ranges.
  n = math_ops.cast(
      math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) /
      math_ops.cast(d, dtypes.float32), dtypes.int32)

  # Carried loop variables ("extra_args") are implicitly added to the input list
  # of the WhileBody function. WhileCond does not call forbody, and so does not
  # depend on any of forbody's extra_args. Since WhileCond and WhileBody
  # must have identical inputs, we have to augment the cond signature to take
  # the same types as the carried loop variables.
  # Layout: (i, n, start, delta) as int32, then forbody's inputs without its
  # first one (the iteration index, which WhileBody computes itself).
  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]

  cond_name = "%s_Cond" % forbody.name

  @function.Defun(*body_sig, func_name=cond_name)
  def WhileCond(i, n, *args):
    del args
    return i < n

  body_name = "%s_Body" % forbody.name

  @function.Defun(*body_sig, func_name=body_name)
  def WhileBody(i, n, start, delta, *args):
    """A While wrapper for forbody that handles loop-carried captured inputs."""
    for_result = forbody(start + i * delta, *args)
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(for_result, ops.Operation):
      for_result = ()
    # Unary functions return a single Tensor value.
    elif isinstance(for_result, ops.Tensor):
      for_result = (for_result,)
    return (i + 1, n, start, delta) + tuple(for_result)

  # The four loop-control values are always host memory; user-supplied
  # indices are shifted past them.
  if hostmem is not None:
    hostmem = [0, 1, 2, 3] + [(4 + _) for _ in hostmem]
  else:
    hostmem = [0, 1, 2, 3]

  results = While(
      input_=[0, n, start, delta] + inputs,
      cond=WhileCond,
      body=WhileBody,
      name=name,
      hostmem=hostmem)
  # Slice off the four loop-control values (i, n, start, delta).
  return list(results[4:])


def For(start,
        limit,
        delta,
        inputs,
        body,
        name=None,
        hostmem=None,
        rewrite_with_while=None):
  r"""out = input; for i in range(start, limit, delta) out = body(i, out).

  Args:
    start: A `Tensor` of type `int32`.
    limit: A `Tensor` of type `int32`.
    delta: A `Tensor` of type `int32`.
    inputs: A list of `Tensor` objects. A list of input tensors whose types are
      T.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as (int32, T...).
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, inputs[i] is a host memory
      tensor. In other words, the (i+1)-th argument of the body function
      expects host memory.
    rewrite_with_while: If True, implement the For loop with a While op (see
      `_ForUsingWhile`).

  Returns:
    A list of output `Tensor` objects with the same types (T) as `inputs`.
  """
  if rewrite_with_while:
    return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)
  if body.captured_inputs:
    ret = gen_functional_ops._for(
        start,
        limit,
        delta,
        inputs + body.captured_inputs,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]
  else:
    ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
  if hostmem:
    # The op's inputs begin with start/limit/delta, so input indices are
    # shifted by 3; output indices are not shifted.
    num_for_params = 3  # start/limit/delta

    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend([num_for_params + i for i in hostmem])
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret


# pylint: enable=invalid-name,protected-access


def partitioned_call(args,
                     f,
                     tout=None,
                     executing_eagerly=None,
                     config=None,
                     executor_type=None):
  """Executes a function while respecting device annotations.

  Currently, only those functions that execute within the same address space
  can be executed.

  Args:
    args: The arguments of the function, including captured inputs.
    f: The function to execute; an instance of `_DefinedFunction` or
      `_EagerDefinedFunction`.
    tout: a list containing the output dtypes enums; if `None`, inferred from
      the signature of `f`.
    executing_eagerly: (Optional) A boolean indicating whether the context is
      executing eagerly. If `None`, fetched from the global context.
    config: (Optional) A `tensorflow::ConfigProto` proto, serialized. If `None`,
      all optimizations are disabled. Currently only handled for eager defined
      functions.
    executor_type: (Optional) A string for the name of the executor to be used
      in the function call. If not set, or set to an empty string, the default
      tensorflow executor will be used.

  Returns:
    The list of `Tensor`s returned by invoking `f(args)`. If the function does
    not return anything, then returns `None` if eager execution is enabled, or
    the `Operation` if not.
  """

  if tout is None:
    tout = tuple(x.type for x in f.definition.signature.output_arg)

  if executing_eagerly is None:
    executing_eagerly = context.executing_eagerly()

  if config is None:
    config = function_utils.get_disabled_rewriter_config()

  if executor_type is None:
    executor_type = ""

  if executing_eagerly:
    # Functions containing stateful ops use the stateful variant of the op.
    if f.stateful_ops:
      outputs = gen_functional_ops.stateful_partitioned_call(
          args=args,
          Tout=tout,
          f=f,
          config_proto=config,
          executor_type=executor_type)
    else:
      outputs = gen_functional_ops.partitioned_call(
          args=args,
          Tout=tout,
          f=f,
          config_proto=config,
          executor_type=executor_type)
    return outputs if outputs else None

  # The generated binding returns an empty list for functions that don't
  # return any Tensors, hence the need to use `create_op` directly.
  args = [ops.convert_to_tensor(x) for x in args]
  tin_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(
          type=[x.dtype.as_datatype_enum for x in args]))
  tout_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(type=tout))
  func_attr = attr_value_pb2.AttrValue(
      func=attr_value_pb2.NameAttrList(name=f.name))
  executor_type_attr = attr_value_pb2.AttrValue(
      s=compat.as_bytes(executor_type))

  # When running in graph mode, the graph and function graphs are optimized
  # (i.e. run through grappler) per the session options, so we can disable any
  # eager-specific rewriting.
  config_proto = attr_value_pb2.AttrValue(s=config)

  graph = ops.get_default_graph()
  f.add_to_graph(graph)
  op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"

  # Propagate the attribute indicating the need to compile from function to the
  # call itself.
  xla_compile_attr = "_XlaMustCompile"
  op_attrs = {
      "Tin": tin_attr,
      "Tout": tout_attr,
      "f": func_attr,
      "config_proto": config_proto,
      "executor_type": executor_type_attr,
  }
  if xla_compile_attr in f.definition.attr:
    op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr]
  op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
  outputs = op.outputs
  if hasattr(f, "graph"):
    _set_read_only_resource_inputs_attr(op, f.graph)
    if hasattr(f.graph, "collective_manager_ids_used"):
      ops.set_int_list_attr(op, acd.COLLECTIVE_MANAGER_IDS,
                            f.graph.collective_manager_ids_used)
  return outputs if outputs else op


def _set_read_only_resource_inputs_attr(op, func_graph):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: PartitionedCall Operation.
    func_graph: FuncGraph.
  """
  read_only_indices = acd.get_read_only_resource_input_indices_graph(func_graph)
  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        read_only_indices)