1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Functional operations."""
16
17from tensorflow.core.framework import attr_value_pb2
18from tensorflow.python.eager import context
19from tensorflow.python.framework import auto_control_deps_utils as acd
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import function
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import gen_functional_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import tensor_array_ops
30from tensorflow.python.ops import variable_scope as vs
31# pylint: disable=unused-import
32from tensorflow.python.ops.gen_functional_ops import remote_call
33# pylint: enable=unused-import
34from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
35from tensorflow.python.util import compat
36from tensorflow.python.util import deprecation
37from tensorflow.python.util import dispatch
38from tensorflow.python.util import function_utils
39from tensorflow.python.util import nest
40from tensorflow.python.util.tf_export import tf_export
41
42
43# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
44@tf_export(v1=["foldl"])
45@dispatch.add_dispatch_support
46def foldl(fn,
47          elems,
48          initializer=None,
49          parallel_iterations=10,
50          back_prop=True,
51          swap_memory=False,
52          name=None):
53  """foldl on the list of tensors unpacked from `elems` on dimension 0.
54
55  This foldl operator repeatedly applies the callable `fn` to a sequence
56  of elements from first to last. The elements are made of the tensors
57  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
58  arguments. The first argument is the accumulated value computed from the
59  preceding invocation of fn, and the second is the value at the current
60  position of `elems`. If `initializer` is None, `elems` must contain at least
61  one element, and its first element is used as the initializer.
62
63  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.
65
66  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
67  is a (possibly nested) list or tuple of tensors, then each of these tensors
68  must have a matching first (unpack) dimension.  The signature of `fn` may
69  match the structure of `elems`.  That is, if `elems` is
70  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda a, (t1, [t2, t3, [t4, t5]]):` (in `python2`), or equivalently
  `fn = lambda a, t:` in `python3`, where `t` matches the structure of `elems`.
72
73  Args:
74    fn: The callable to be performed.
75    elems: A tensor or (possibly nested) sequence of tensors, each of which will
76      be unpacked along their first dimension.  The nested sequence of the
77      resulting slices will be the first argument to `fn`.
78    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
79      as the initial value for the accumulator.
80    parallel_iterations: (optional) The number of iterations allowed to run in
81      parallel.
82    back_prop: (optional) True enables support for back propagation.
83    swap_memory: (optional) True enables GPU-CPU memory swapping.
84    name: (optional) Name prefix for the returned tensors.
85
86  Returns:
87    A tensor or (possibly nested) sequence of tensors, resulting from applying
88    `fn` consecutively to the list of tensors unpacked from `elems`, from first
89    to last.
90
91  Raises:
92    TypeError: if `fn` is not callable.
93
94  Example:
95    ```python
96    elems = tf.constant([1, 2, 3, 4, 5, 6])
97    sum = foldl(lambda a, x: a + x, elems)
98    # sum == 21
99    ```
100  """
101  if not callable(fn):
102    raise TypeError(
103        f"{fn.__name__} is not callable. Please provide a callable function.")
104
105  def create_ta(elem):
106    return tensor_array_ops.TensorArray(
107        dtype=elem.dtype, size=n, dynamic_size=False,
108        infer_shape=True).unstack(elem)
109
110  in_graph_mode = not context.executing_eagerly()
111  with ops.name_scope(name, "foldl", [elems]):
112    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
113    # supported in Eager
114    if in_graph_mode:
115      # Any get_variable calls in fn will cache the first call locally
116      # and not issue repeated network I/O requests for each iteration.
117      varscope = vs.get_variable_scope()
118      varscope_caching_device_was_none = False
119      if varscope.caching_device is None:
120        # TODO(ebrevdo): Change to using colocate_with here and in other
121        # methods.
122        varscope.set_caching_device(lambda op: op.device)
123        varscope_caching_device_was_none = True
124
125    # Convert elems to tensor array. n may be known statically.
126    elems_flat = [
127        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
128    ]
129    n = (
130        tensor_shape.dimension_value(elems_flat[0].shape[0]) or
131        array_ops.shape(elems_flat[0])[0])
132
133    elems_ta = nest.map_structure(create_ta, elems)
134
135    if initializer is None:
136      a = nest.map_structure(lambda elem: elem.read(0), elems_ta)
137      i = constant_op.constant(1)
138    else:
139      a = initializer
140      i = constant_op.constant(0)
141
142    def compute(i, a):
143      elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta)
144      a = fn(a, elem_i)
145      return [i + 1, a]
146
147    _, r_a = control_flow_ops.while_loop(
148        lambda i, a: i < n,
149        compute, [i, a],
150        parallel_iterations=parallel_iterations,
151        back_prop=back_prop,
152        swap_memory=swap_memory,
153        maximum_iterations=n)
154
155    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
156    # supported in Eager
157    if in_graph_mode and varscope_caching_device_was_none:
158      varscope.set_caching_device(None)
159
160    return r_a
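
# Example with an `initializer` and structured `elems` (an illustrative sketch,
# not part of the original docs; the constants below are made up):
#
#   elems = (tf.constant([1, 2, 3]), tf.constant([10, 20, 30]))
#   total = foldl(lambda a, t: a + t[0] * t[1], elems,
#                 initializer=tf.constant(0))
#   # total == 1 * 10 + 2 * 20 + 3 * 30 == 140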
161
162
163@tf_export("foldl", v1=[])
164@dispatch.add_dispatch_support
165@deprecation.deprecated_arg_values(
166    None,
167    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
168Instead of:
169results = tf.foldl(fn, elems, back_prop=False)
170Use:
171results = tf.nest.map_structure(tf.stop_gradient, tf.foldl(fn, elems))""",
172    warn_once=True,
173    back_prop=False)
174def foldl_v2(fn,
175             elems,
176             initializer=None,
177             parallel_iterations=10,
178             back_prop=True,
179             swap_memory=False,
180             name=None):
181  """foldl on the list of tensors unpacked from `elems` on dimension 0.
182
183  This foldl operator repeatedly applies the callable `fn` to a sequence
184  of elements from first to last. The elements are made of the tensors
185  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
186  arguments. The first argument is the accumulated value computed from the
187  preceding invocation of fn, and the second is the value at the current
188  position of `elems`. If `initializer` is None, `elems` must contain at least
189  one element, and its first element is used as the initializer.
190
191  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.
193
194  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
195  is a (possibly nested) list or tuple of tensors, then each of these tensors
196  must have a matching first (unpack) dimension.  The signature of `fn` may
197  match the structure of `elems`.  That is, if `elems` is
198  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda a, (t1, [t2, t3, [t4, t5]]):` (in `python2`), or equivalently
  `fn = lambda a, t:` in `python3`, where `t` matches the structure of `elems`.
200
201  Args:
202    fn: The callable to be performed.
203    elems: A tensor or (possibly nested) sequence of tensors, each of which will
204      be unpacked along their first dimension.  The nested sequence of the
205      resulting slices will be the first argument to `fn`.
206    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
207      as the initial value for the accumulator.
208    parallel_iterations: (optional) The number of iterations allowed to run in
209      parallel.
210    back_prop: (optional) Deprecated. False disables support for back
211      propagation. Prefer using `tf.stop_gradient` instead.
212    swap_memory: (optional) True enables GPU-CPU memory swapping.
213    name: (optional) Name prefix for the returned tensors.
214
215  Returns:
216    A tensor or (possibly nested) sequence of tensors, resulting from applying
217    `fn` consecutively to the list of tensors unpacked from `elems`, from first
218    to last.
219
220  Raises:
221    TypeError: if `fn` is not callable.
222
223  Example:
224    ```python
225    elems = tf.constant([1, 2, 3, 4, 5, 6])
226    sum = tf.foldl(lambda a, x: a + x, elems)
227    # sum == 21
228    ```
229  """
230  return foldl(
231      fn=fn,
232      elems=elems,
233      initializer=initializer,
234      parallel_iterations=parallel_iterations,
235      back_prop=back_prop,
236      swap_memory=swap_memory,
237      name=name)
238
239
240@tf_export(v1=["foldr"])
241@dispatch.add_dispatch_support
242def foldr(fn,
243          elems,
244          initializer=None,
245          parallel_iterations=10,
246          back_prop=True,
247          swap_memory=False,
248          name=None):
249  """foldr on the list of tensors unpacked from `elems` on dimension 0.
250
251  This foldr operator repeatedly applies the callable `fn` to a sequence
252  of elements from last to first. The elements are made of the tensors
253  unpacked from `elems`. The callable fn takes two tensors as arguments.
254  The first argument is the accumulated value computed from the preceding
255  invocation of fn, and the second is the value at the current position of
  `elems`. If `initializer` is None, `elems` must contain at least one element,
  and its last element is used as the initializer.
258
259  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
260  of the result tensor is `fn(initializer, values[0]).shape`.
261
262  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
263  is a (possibly nested) list or tuple of tensors, then each of these tensors
264  must have a matching first (unpack) dimension.  The signature of `fn` may
265  match the structure of `elems`.  That is, if `elems` is
266  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda a, (t1, [t2, t3, [t4, t5]]):` (in `python2`), or equivalently
  `fn = lambda a, t:` in `python3`, where `t` matches the structure of `elems`.
268
269  Args:
270    fn: The callable to be performed.
271    elems: A tensor or (possibly nested) sequence of tensors, each of which will
272      be unpacked along their first dimension.  The nested sequence of the
273      resulting slices will be the first argument to `fn`.
274    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
275      as the initial value for the accumulator.
276    parallel_iterations: (optional) The number of iterations allowed to run in
277      parallel.
278    back_prop: (optional) True enables support for back propagation.
279    swap_memory: (optional) True enables GPU-CPU memory swapping.
280    name: (optional) Name prefix for the returned tensors.
281
282  Returns:
283    A tensor or (possibly nested) sequence of tensors, resulting from applying
284    `fn` consecutively to the list of tensors unpacked from `elems`, from last
285    to first.
286
287  Raises:
288    TypeError: if `fn` is not callable.
289
290  Example:
291    ```python
292    elems = [1, 2, 3, 4, 5, 6]
293    sum = foldr(lambda a, x: a + x, elems)
294    # sum == 21
295    ```
296  """
297  if not callable(fn):
298    raise TypeError(
299        f"{fn.__name__} is not callable. Please provide a callable function.")
300
301  def create_ta(elem):
302    return tensor_array_ops.TensorArray(
303        dtype=elem.dtype, size=n, dynamic_size=False,
304        infer_shape=True).unstack(elem)
305
306  in_graph_mode = not context.executing_eagerly()
307  with ops.name_scope(name, "foldr", [elems]):
308    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
309    # supported in Eager
310    if in_graph_mode:
311      # Any get_variable calls in fn will cache the first call locally and not
312      # issue repeated network I/O requests for each iteration.
313      varscope = vs.get_variable_scope()
314      varscope_caching_device_was_none = False
315      if varscope.caching_device is None:
316        # TODO(ebrevdo): Change to using colocate_with here and in other
317        # methods.
318        varscope.set_caching_device(lambda op: op.device)
319        varscope_caching_device_was_none = True
320
321    # Convert elems to tensor array. n may be known statically.
322    elems_flat = [
323        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
324    ]
325    n = (
326        tensor_shape.dimension_value(elems_flat[0].shape[0]) or
327        array_ops.shape(elems_flat[0])[0])
328
329    elems_ta = nest.map_structure(create_ta, elems)
330
331    if initializer is None:
332      i = n - 1
333      a = nest.map_structure(lambda elem: elem.read(i), elems_ta)
334    else:
335      i = n
336      a = initializer
337
338    def compute(i, a):
339      i -= 1
340      elem = nest.map_structure(lambda elem: elem.read(i), elems_ta)
341      a_out = fn(a, elem)
342      return [i, a_out]
343
344    _, r_a = control_flow_ops.while_loop(
345        lambda i, a: i > 0,
346        compute, [i, a],
347        parallel_iterations=parallel_iterations,
348        back_prop=back_prop,
349        swap_memory=swap_memory,
350        maximum_iterations=n)
351
352    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
353    # supported in Eager
354    if in_graph_mode and varscope_caching_device_was_none:
355      varscope.set_caching_device(None)
356
357    return r_a
358
359
360@tf_export("foldr", v1=[])
361@dispatch.add_dispatch_support
362@deprecation.deprecated_arg_values(
363    None,
364    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
365Instead of:
366results = tf.foldr(fn, elems, back_prop=False)
367Use:
368results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))""",
369    warn_once=True,
370    back_prop=False)
371def foldr_v2(fn,
372             elems,
373             initializer=None,
374             parallel_iterations=10,
375             back_prop=True,
376             swap_memory=False,
377             name=None):
378  """foldr on the list of tensors unpacked from `elems` on dimension 0.
379
380  This foldr operator repeatedly applies the callable `fn` to a sequence
381  of elements from last to first. The elements are made of the tensors
382  unpacked from `elems`. The callable fn takes two tensors as arguments.
383  The first argument is the accumulated value computed from the preceding
384  invocation of fn, and the second is the value at the current position of
  `elems`. If `initializer` is None, `elems` must contain at least one element,
  and its last element is used as the initializer.
387
388  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
389  of the result tensor is `fn(initializer, values[0]).shape`.
390
391  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
392  is a (possibly nested) list or tuple of tensors, then each of these tensors
393  must have a matching first (unpack) dimension.  The signature of `fn` may
394  match the structure of `elems`.  That is, if `elems` is
395  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda a, (t1, [t2, t3, [t4, t5]]):` (in `python2`), or equivalently
  `fn = lambda a, t:` in `python3`, where `t` matches the structure of `elems`.
397
398  Args:
399    fn: The callable to be performed.
400    elems: A tensor or (possibly nested) sequence of tensors, each of which will
401      be unpacked along their first dimension.  The nested sequence of the
402      resulting slices will be the first argument to `fn`.
403    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
404      as the initial value for the accumulator.
405    parallel_iterations: (optional) The number of iterations allowed to run in
406      parallel.
407    back_prop: (optional) Deprecated. False disables support for back
408      propagation. Prefer using `tf.stop_gradient` instead.
409    swap_memory: (optional) True enables GPU-CPU memory swapping.
410    name: (optional) Name prefix for the returned tensors.
411
412  Returns:
413    A tensor or (possibly nested) sequence of tensors, resulting from applying
414    `fn` consecutively to the list of tensors unpacked from `elems`, from last
415    to first.
416
417  Raises:
418    TypeError: if `fn` is not callable.
419
420  Example:
421    ```python
422    elems = [1, 2, 3, 4, 5, 6]
423    sum = tf.foldr(lambda a, x: a + x, elems)
424    # sum == 21
425    ```
426  """
427  return foldr(
428      fn=fn,
429      elems=elems,
430      initializer=initializer,
431      parallel_iterations=parallel_iterations,
432      back_prop=back_prop,
433      swap_memory=swap_memory,
434      name=name)
435
436
437@tf_export(v1=["scan"])
438@dispatch.add_dispatch_support
439def scan(fn,
440         elems,
441         initializer=None,
442         parallel_iterations=10,
443         back_prop=True,
444         swap_memory=False,
445         infer_shape=True,
446         reverse=False,
447         name=None):
448  """scan on the list of tensors unpacked from `elems` on dimension 0.
449
450  See also `tf.map_fn`.
451
452  The simplest version of `scan` repeatedly applies the callable `fn` to a
453  sequence of elements from first to last. The elements are made of the tensors
454  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
455  arguments. The first argument is the accumulated value computed from the
456  preceding invocation of fn, and the second is the value at the current
457  position of `elems`. If `initializer` is None, `elems` must contain at least
458  one element, and its first element is used as the initializer.
459
460  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
461  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If `reverse=True`, it's `fn(initializer, values[-1]).shape`.
463
464  This method also allows multi-arity `elems` and accumulator.  If `elems`
465  is a (possibly nested) list or tuple of tensors, then each of these tensors
466  must have a matching first (unpack) dimension.  The second argument of
467  `fn` must match the structure of `elems`.
468
469  If no `initializer` is provided, the output structure and dtypes of `fn`
470  are assumed to be the same as its input; and in this case, the first
471  argument of `fn` must match the structure of `elems`.
472
473  If an `initializer` is provided, then the output of `fn` must have the same
474  structure as `initializer`; and the first argument of `fn` must match
475  this structure.
476
477  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
478  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
479  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
480  `[acc_n1, acc_n2]`.  An alternative correct signature for `fn`, and the
481   one that works in `python3`, is:
482  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.
483
484  Args:
485    fn: The callable to be performed.  It accepts two arguments.  The first will
486      have the same structure as `initializer` if one is provided, otherwise it
487      will have the same structure as `elems`.  The second will have the same
488      (possibly nested) structure as `elems`.  Its output must have the same
489      structure as `initializer` if one is provided, otherwise it must have the
490      same structure as `elems`.
491    elems: A tensor or (possibly nested) sequence of tensors, each of which will
492      be unpacked along their first dimension.  The nested sequence of the
493      resulting slices will be the first argument to `fn`.
494    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
495      initial value for the accumulator, and the expected output type of `fn`.
496    parallel_iterations: (optional) The number of iterations allowed to run in
497      parallel.
498    back_prop: (optional) True enables support for back propagation.
499    swap_memory: (optional) True enables GPU-CPU memory swapping.
500    infer_shape: (optional) False disables tests for consistent output shapes.
501    reverse: (optional) True scans the tensor last to first (instead of first to
502      last).
503    name: (optional) Name prefix for the returned tensors.
504
505  Returns:
506    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
507    results of applying `fn` to tensors unpacked from `elems` along the first
508    dimension, and the previous accumulator value(s), from first to last (or
509    last to first, if `reverse=True`).
510
511  Raises:
512    TypeError: if `fn` is not callable or the structure of the output of
513      `fn` and `initializer` do not match.
514    ValueError: if the lengths of the output of `fn` and `initializer`
515      do not match.
516
517  Examples:
518    ```python
519    elems = np.array([1, 2, 3, 4, 5, 6])
520    sum = scan(lambda a, x: a + x, elems)
521    # sum == [1, 3, 6, 10, 15, 21]
522    sum = scan(lambda a, x: a + x, elems, reverse=True)
523    # sum == [21, 20, 18, 15, 11, 6]
524    ```
525
526    ```python
527    elems = np.array([1, 2, 3, 4, 5, 6])
528    initializer = np.array(0)
529    sum_one = scan(
530        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
531    # sum_one == [1, 2, 3, 4, 5, 6]
532    ```
533
534    ```python
535    elems = np.array([1, 0, 0, 0, 0, 0])
536    initializer = (np.array(0), np.array(1))
537    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
538    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
539    ```
540  """
541  if not callable(fn):
542    raise TypeError(
543        f"{fn.__name__} is not callable. Please provide a callable function.")
544
545  input_is_sequence = nest.is_nested(elems)
546  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
547
548  def input_pack(x):
549    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]
550
551  if initializer is None:
552    output_is_sequence = input_is_sequence
553    output_flatten = input_flatten
554    output_pack = input_pack
555  else:
556    output_is_sequence = nest.is_nested(initializer)
557    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]
558
559    def output_pack(x):
560      return (nest.pack_sequence_as(initializer, x)
561              if output_is_sequence else x[0])
562
563  elems_flat = input_flatten(elems)
564
565  in_graph_mode = not context.executing_eagerly()
566  with ops.name_scope(name, "scan", elems_flat):
567    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
568    # supported in Eager
569    if in_graph_mode:
570      # Any get_variable calls in fn will cache the first call locally
571      # and not issue repeated network I/O requests for each iteration.
572      varscope = vs.get_variable_scope()
573      varscope_caching_device_was_none = False
574      if varscope.caching_device is None:
575        # TODO(ebrevdo): Change to using colocate_with here and in other
576        # methods.
577        varscope.set_caching_device(lambda op: op.device)
578        varscope_caching_device_was_none = True
579
    # Convert elems to tensors.
581    elems_flat = [
582        ops.convert_to_tensor(elem, name="elem") for elem in elems_flat
583    ]
584
    # n may be known statically.
586    n = tensor_shape.dimension_value(elems_flat[0].shape[0])
587    if n is None:
588      n = array_ops.shape(elems_flat[0])[0]
589
590    # TensorArrays are always flat
591    elems_ta = [
592        tensor_array_ops.TensorArray(
593            dtype=elem.dtype,
594            size=n,
595            dynamic_size=False,
596            element_shape=elem.shape[1:],
597            infer_shape=True) for elem in elems_flat
598    ]
599    # Unpack elements
600    elems_ta = [
601        elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)
602    ]
603
604    if initializer is None:
605      a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
606      i = 1
607    else:
608      initializer_flat = output_flatten(initializer)
609      a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
610      i = 0
611
612    # Create a tensor array to store the intermediate values.
613    accs_ta = [
614        tensor_array_ops.TensorArray(
615            dtype=init.dtype,
616            size=n,
617            element_shape=init.shape if infer_shape else None,
618            dynamic_size=False,
619            infer_shape=infer_shape) for init in a_flat
620    ]
621
622    if initializer is None:
623      accs_ta = [
624          acc_ta.write(n - 1 if reverse else 0, a)
625          for (acc_ta, a) in zip(accs_ta, a_flat)
626      ]
627
628    def compute(i, a_flat, tas):
629      """The loop body of scan.
630
631      Args:
632        i: the loop counter.
633        a_flat: the accumulator value(s), flattened.
634        tas: the output accumulator TensorArray(s), flattened.
635
636      Returns:
637        [i + 1, a_flat, tas]: the updated counter + new accumulator values +
638          updated TensorArrays
639
640      Raises:
641        TypeError: if initializer and fn() output structure do not match
        ValueError: if initializer and fn() output lengths do not match
643      """
644      packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
645      packed_a = output_pack(a_flat)
646      a_out = fn(packed_a, packed_elems)
647      nest.assert_same_structure(elems if initializer is None else initializer,
648                                 a_out)
649      flat_a_out = output_flatten(a_out)
650      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
651      if reverse:
652        next_i = i - 1
653      else:
654        next_i = i + 1
655      return (next_i, flat_a_out, tas)
656
657    if reverse:
658      initial_i = n - 1 - i
659      condition = lambda i, _1, _2: i >= 0
660    else:
661      initial_i = i
662      condition = lambda i, _1, _2: i < n
663    _, _, r_a = control_flow_ops.while_loop(
664        condition,
665        compute, (initial_i, a_flat, accs_ta),
666        parallel_iterations=parallel_iterations,
667        back_prop=back_prop,
668        swap_memory=swap_memory,
669        maximum_iterations=n)
670
671    results_flat = [r.stack() for r in r_a]
672
673    n_static = tensor_shape.Dimension(
674        tensor_shape.dimension_value(
675            elems_flat[0].get_shape().with_rank_at_least(1)[0]))
676    for elem in elems_flat[1:]:
677      n_static.assert_is_compatible_with(
678          tensor_shape.Dimension(
679              tensor_shape.dimension_value(
680                  elem.get_shape().with_rank_at_least(1)[0])))
681    for r in results_flat:
682      r.set_shape(
683          tensor_shape.TensorShape(n_static).concatenate(r.get_shape()[1:]))
684
685    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
686    # supported in Eager
687    if in_graph_mode and varscope_caching_device_was_none:
688      varscope.set_caching_device(None)
689
690    return output_pack(results_flat)
691
692
693@tf_export("scan", v1=[])
694@dispatch.add_dispatch_support
695@deprecation.deprecated_arg_values(
696    None,
697    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
698Instead of:
699results = tf.scan(fn, elems, back_prop=False)
700Use:
701results = tf.nest.map_structure(tf.stop_gradient, tf.scan(fn, elems))""",
702    warn_once=True,
703    back_prop=False)
704def scan_v2(fn,
705            elems,
706            initializer=None,
707            parallel_iterations=10,
708            back_prop=True,
709            swap_memory=False,
710            infer_shape=True,
711            reverse=False,
712            name=None):
713  """scan on the list of tensors unpacked from `elems` on dimension 0.
714
715  The simplest version of `scan` repeatedly applies the callable `fn` to a
716  sequence of elements from first to last. The elements are made of the tensors
717  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
718  arguments. The first argument is the accumulated value computed from the
719  preceding invocation of fn, and the second is the value at the current
720  position of `elems`. If `initializer` is None, `elems` must contain at least
721  one element, and its first element is used as the initializer.
722
723  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
724  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If `reverse=True`, it's `fn(initializer, values[-1]).shape`.
726
727  This method also allows multi-arity `elems` and accumulator.  If `elems`
728  is a (possibly nested) list or tuple of tensors, then each of these tensors
729  must have a matching first (unpack) dimension.  The second argument of
730  `fn` must match the structure of `elems`.
731
732  If no `initializer` is provided, the output structure and dtypes of `fn`
733  are assumed to be the same as its input; and in this case, the first
734  argument of `fn` must match the structure of `elems`.
735
736  If an `initializer` is provided, then the output of `fn` must have the same
737  structure as `initializer`; and the first argument of `fn` must match
738  this structure.
739
740  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
741  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
742  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
743  `[acc_n1, acc_n2]`.  An alternative correct signature for `fn`, and the
744   one that works in `python3`, is:
745  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.
746
747  Args:
748    fn: The callable to be performed.  It accepts two arguments.  The first will
749      have the same structure as `initializer` if one is provided, otherwise it
750      will have the same structure as `elems`.  The second will have the same
751      (possibly nested) structure as `elems`.  Its output must have the same
752      structure as `initializer` if one is provided, otherwise it must have the
753      same structure as `elems`.
754    elems: A tensor or (possibly nested) sequence of tensors, each of which will
755      be unpacked along their first dimension.  The nested sequence of the
756      resulting slices will be the first argument to `fn`.
757    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
758      initial value for the accumulator, and the expected output type of `fn`.
759    parallel_iterations: (optional) The number of iterations allowed to run in
760      parallel.
761    back_prop: (optional) Deprecated. False disables support for back
762      propagation. Prefer using `tf.stop_gradient` instead.
763    swap_memory: (optional) True enables GPU-CPU memory swapping.
764    infer_shape: (optional) False disables tests for consistent output shapes.
765    reverse: (optional) True scans the tensor last to first (instead of first to
766      last).
767    name: (optional) Name prefix for the returned tensors.
768
769  Returns:
770    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
771    results of applying `fn` to tensors unpacked from `elems` along the first
772    dimension, and the previous accumulator value(s), from first to last (or
773    last to first, if `reverse=True`).
774
775  Raises:
776    TypeError: if `fn` is not callable or the structure of the output of
777      `fn` and `initializer` do not match.
778    ValueError: if the lengths of the output of `fn` and `initializer`
779      do not match.
780
781  Examples:
782    ```python
783    elems = np.array([1, 2, 3, 4, 5, 6])
784    sum = scan(lambda a, x: a + x, elems)
785    # sum == [1, 3, 6, 10, 15, 21]
786    sum = scan(lambda a, x: a + x, elems, reverse=True)
787    # sum == [21, 20, 18, 15, 11, 6]
788    ```
789
790    ```python
791    elems = np.array([1, 2, 3, 4, 5, 6])
792    initializer = np.array(0)
793    sum_one = scan(
794        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
795    # sum_one == [1, 2, 3, 4, 5, 6]
796    ```
797
798    ```python
799    elems = np.array([1, 0, 0, 0, 0, 0])
800    initializer = (np.array(0), np.array(1))
801    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
802    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
803    ```
804  """
805  return scan(
806      fn=fn,
807      elems=elems,
808      initializer=initializer,
809      parallel_iterations=parallel_iterations,
810      back_prop=back_prop,
811      swap_memory=swap_memory,
812      infer_shape=infer_shape,
813      reverse=reverse,
814      name=name)
815
816
817# pylint: disable=invalid-name
818def If(cond, inputs, then_branch, else_branch, name=None):
819  r"""output = Cond(inputs) ?
820
821  then_branch(inputs) : else_branch(inputs).
822
823  Args:
824    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
825      converted to a boolean according to the following rule: if the scalar is a
826        numerical value, non-zero means True and zero means False; if the scalar
827        is a string, non-empty means True and empty means False.
828    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what then_branch returns.
833    name: A name for the operation (optional).
834
835  Returns:
836    A list of tensors returned by either then_branch(inputs)
837    or else_branch(inputs).
838  """
839  # pylint: disable=protected-access
840  # Handle the Defun case until users have transitioned to tf.function. Note
841  # that composites may need to be re-packed by the caller.
842  if isinstance(then_branch, function._DefinedFunction):
843    tlist = [_.type for _ in then_branch.definition.signature.output_arg]
844    return gen_functional_ops._if(
845        cond, inputs, tlist, then_branch, else_branch, name=name)
846
847  # We assume that `then_branch` is a ConcreteFunction here.
848  then_out = then_branch.structured_outputs
849  else_out = else_branch.structured_outputs
850
851  # Ensure then/else are the same type of composites to avoid an invalid call
852  # to pack_sequence_as later on.
853  nest.assert_same_structure(then_out, else_out, expand_composites=True)
854
855  tlist = nest.flatten(then_branch.output_dtypes)
856  ret = gen_functional_ops._if(
857      cond, inputs, tlist, then_branch, else_branch, name=name)
858
859  # Re-pack the outputs to restore any CompositeTensors
860  return nest.pack_sequence_as(then_out, ret, expand_composites=True)
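
# Example (an illustrative sketch, not taken from this module; it assumes graph
# mode with `function.Defun` branches, and the float32 signatures below are
# made up):
#
#   @function.Defun(dtypes.float32)
#   def TwiceF(x):
#     return x * 2.0
#
#   @function.Defun(dtypes.float32)
#   def HalfF(x):
#     return x * 0.5
#
#   # cond is non-zero, so this runs TwiceF on [3.0] and returns [6.0].
#   result = If(cond=constant_op.constant(1),
#               inputs=[constant_op.constant(3.0)],
#               then_branch=TwiceF, else_branch=HalfF)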
861
862
863def Gradient(inputs, f, name=None):
864  r"""Computes the gradient function for function f via backpropagation.
865
866  Args:
867    inputs: A list of tensors of size N + M.
    f: The function we want to compute the gradient for.  The function 'f' must
      be a numerical function which takes N inputs and produces M outputs. Its
      gradient function 'g' is a function that takes N + M inputs and produces
      N outputs.  I.e., if we have (y1, y2, ..., yM) = f(x1, x2, ..., xN),
      then g computes (dL/dx1, dL/dx2, ..., dL/dxN) =
      g(x1, x2, ..., xN, dL/dy1, dL/dy2, ..., dL/dyM), where L is a
      scalar-valued function of (x1, x2, ..., xN) (e.g., the loss function).
      dL/dxi is the partial derivative of L with respect to xi.
876    name: A name for the operation (optional).
877
878  Returns:
879    A list of tensors of size N.
880  """
881  # TODO(zhifengc): Pretty-print the above spec in latex.
  # TODO(zhifengc): Needs a math expert to state the comment above better.
883  tlist = [_.type for _ in f.definition.signature.input_arg]
884  return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name)
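
# Example (an illustrative sketch; `Square` and its float32 signature are
# invented for illustration, and graph mode with `function.Defun` is assumed):
#
#   @function.Defun(dtypes.float32)
#   def Square(x):  # N = 1 input, M = 1 output
#     return x * x
#
#   x = constant_op.constant(3.0)
#   dy = constant_op.constant(1.0)  # dL/dy
#   # g(x, dL/dy) computes dL/dx = 2 * x * dL/dy, so dx evaluates to 6.0.
#   [dx] = Gradient([x, dy], Square)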
885
886
887def _GetInputDtypes(func):
888  """Returns the input dtypes of func, excluding dtypes for captured inputs."""
889  if isinstance(func, function._DefinedFunction):  # pylint: disable=protected-access
890    return func.declared_input_types
891
892  # We assume that `func` is a ConcreteFunction here, but we are not able to
893  # verify since importing eager function library will cause cyclic dependence.
894  #
895  # ConcreteFunction.inputs includes captured inputs.
896  num_non_captured_inputs = len(func.inputs) - len(func.captured_inputs)
897  inputs_without_captured = func.inputs[:num_non_captured_inputs]
898  return [t.dtype for t in inputs_without_captured]
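
# For instance (illustrative): a ConcreteFunction traced with two float32
# arguments that also captures one variable has three entries in `inputs` and
# one in `captured_inputs`, so this helper returns only the first two dtypes.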
899
900
901def _LoopBodyCaptureWrapper(func):
902  """Returns a wrapper for `func` that handles loop-carried captured inputs."""
903
904  @function.Defun(*_GetInputDtypes(func), func_name="%s_Wrapper" % func.name)
905  def Wrapper(*args):
906    """A wrapper that handles loop-carried captured inputs."""
907    result = func(*args)
908    extra_args = tuple(function.get_extra_args())
909    # Nullary functions return an Operation. Normal functions can't do this
910    # because their return values are converted to Tensors.
911    if isinstance(result, ops.Operation):
912      return extra_args
913    # Unary functions return a single Tensor value.
914    elif not isinstance(result, (list, tuple)):
915      return (result,) + extra_args
916    # N-ary functions return a tuple of Tensors.
917    else:
918      return result + type(result)(extra_args)
919
920  return Wrapper
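
# Sketch of the rewrite performed above (illustrative; `captured_c` is a
# made-up name for an implicitly captured tensor): a loop body such as
#
#   @function.Defun(dtypes.int32)
#   def Body(x):
#     return x + captured_c
#
# becomes a wrapper that returns (x + captured_c, captured_c), i.e. the capture
# is echoed back as an extra loop-carried value.  `While`/`For` below append
# `body.captured_inputs` to the op inputs and later slice these extras off.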
921
922
923# pylint: disable=invalid-name,protected-access
924def While(input_, cond, body, name=None, hostmem=None):
925  r"""output = input; While (Cond(output)) { output = Body(output) }.
926
927  Args:
928    input_: A list of `Tensor` objects. A list of input tensors whose types are
929      T.
    cond: A function that takes 'input' and returns a tensor.  If the tensor
      is a non-boolean scalar, it is converted to a boolean according to the
      following rule: if the scalar is a numerical value, non-zero means True
      and zero means False; if the scalar is a string, non-empty means True
      and empty means False. If the tensor is not a scalar, non-emptiness
      means True and emptiness means False.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as specified by T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, input_[i] is a host
      memory tensor.
941
942  Raises:
943    ValueError: if `cond` has implicitly captured inputs or if `cond` and `body`
944      have different signatures.
945
946  Returns:
    A list of output `Tensor` objects whose types are T (the same types as
    `input_`).
949  """
950  if cond.captured_inputs:
951    raise ValueError(
952        "The 'cond' argument can not have implicitly captured inputs. Received "
953        f"captured_inputs: {cond.captured_inputs}")
954
955  cond_input_types = _GetInputDtypes(cond)
956  body_input_types = _GetInputDtypes(body)
957
958  if cond_input_types != body_input_types:
959    raise ValueError(
960        "The 'cond' and 'body' signatures do not match. Received: "
961        f"cond_input_types={cond_input_types}, body_input_types="
962        f"{body_input_types}")
963
964  if body.captured_inputs:
965    cond_dtypes = list(body_input_types) + [
966        t.dtype for t in body.captured_inputs
967    ]
968
969    @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name)
970    def CondWrapper(*args):
971      """A wrapper that handles loop-carried captured inputs."""
972      return cond(*args[:len(body_input_types)])
973
974    ret = gen_functional_ops._while(
975        input_ + body.captured_inputs,
976        CondWrapper,
977        _LoopBodyCaptureWrapper(body),
978        name=name)
979    # Slice off the loop-carried captured inputs.
980    ret = ret[:-len(body.captured_inputs)]
981  else:
982    ret = gen_functional_ops._while(input_, cond, body, name=name)
983  if hostmem:
984    input_attr = attr_value_pb2.AttrValue()
985    input_attr.list.i.extend(hostmem)
986    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access
987
988    output_attr = attr_value_pb2.AttrValue()
989    output_attr.list.i.extend(hostmem)
990    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
991  return ret
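
# Example (an illustrative sketch; graph mode and the int32 signatures are
# assumptions, not part of this module):
#
#   @function.Defun(dtypes.int32)
#   def Cond(x):
#     return x < 10
#
#   @function.Defun(dtypes.int32)
#   def Body(x):
#     return x + 1
#
#   # Repeatedly applies Body while Cond holds; the result is [10].
#   result = While([constant_op.constant(0)], Cond, Body)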
992
993
994# b/36459430
995#
# Ideally, we would not need to rewrite a For loop into a While loop.
997# However, today, if a While runs on GPU and the condition returns a
998# boolean, the While kernel crashes. Even if we fix the crash, the
999# bool needs to be copied between GPU and CPU. So, a for loop is much
1000# preferred when running on GPU.
1001#
# On the other hand, the For op has no direct XLA kernel. So, when we run
1003# a for loop, we need to rewrite it using a While op.
1004#
# It should be possible and probably better to write an XLA C++ kernel
1006# implementing the logic in _ForUsingWhile.
1007def _ForUsingWhile(start,
1008                   limit,
1009                   delta,
1010                   inputs,
1011                   forbody,
1012                   name=None,
1013                   hostmem=None):
1014  """Helper to implement a For loop using a While."""
1015  # To support negative delta (e.g., range(100, 0, -3)), we iterate
1016  # over the range(n) and use iter * delta + start as the real
1017  # iteration index. (e.g., for i in range(34): iter = i * (-3) +
1018  # 100).
1019  d = math_ops.abs(delta)
1020  # XLA on TPUs doesn't support integer division
1021  n = math_ops.cast(
1022      math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) /
1023      math_ops.cast(d, dtypes.float32), dtypes.int32)
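  # Worked instance of the comment above (illustrative): start=100, limit=0,
  # delta=-3 gives d = 3 and n = ceil(|0 - 100| / 3) = 34, and iteration i
  # uses index 100 + i * (-3).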
1024
1025  # Carried loop variables ("extra_args") are implicitly added to the input list
1026  # of the WhileBody function. WhileCond does not call forbody, and so does not
1027  # depend on any of forbody's extra_args. Since WhileCond and WhileBody
1028  # must have identical inputs, we have to augment the cond signature to take
1029  # the same types as the carried loop variables.
1030  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]
1031
1032  cond_name = "%s_Cond" % forbody.name
1033
1034  @function.Defun(*body_sig, func_name=cond_name)
1035  def WhileCond(i, n, *args):
1036    del args
1037    return i < n
1038
1039  body_name = "%s_Body" % forbody.name
1040
1041  @function.Defun(*body_sig, func_name=body_name)
1042  def WhileBody(i, n, start, delta, *args):
1043    """A While wrapper for forbody that handles loop-carried captured inputs."""
1044    for_result = forbody(start + i * delta, *args)
1045    # Nullary functions return an Operation. Normal functions can't do this
1046    # because their return values are converted to Tensors.
1047    if isinstance(for_result, ops.Operation):
1048      for_result = ()
1049    # Unary functions return a single Tensor value.
1050    elif isinstance(for_result, ops.Tensor):
1051      for_result = (for_result,)
1052    return (i + 1, n, start, delta) + tuple(for_result)
1053
1054  if hostmem is not None:
1055    hostmem = [0, 1, 2, 3] + [(4 + _) for _ in hostmem]
1056  else:
1057    hostmem = [0, 1, 2, 3]
1058
1059  results = While(
1060      input_=[0, n, start, delta] + inputs,
1061      cond=WhileCond,
1062      body=WhileBody,
1063      name=name,
1064      hostmem=hostmem)
1065  # Slice off the loop-carried captured inputs.
1066  return list(results[4:len(results)])
1067
1068
1069def For(start,
1070        limit,
1071        delta,
1072        inputs,
1073        body,
1074        name=None,
1075        hostmem=None,
1076        rewrite_with_while=None):
1077  r"""out = input; for i in range(start, limit, delta) out = body(i, out).
1078
1079  Args:
1080    start: A `Tensor` of type `int32`.
1081    limit: A `Tensor` of type `int32`.
1082    delta: A `Tensor` of type `int32`.
1083    inputs: A list of `Tensor` objects. A list of input tensors whose types are
1084      T.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as (int32, T...).
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, inputs[i] is a host
      memory tensor. In other words, the (i+1)-th argument of the body
      function expects a host memory tensor.
    rewrite_with_while: If True, use the While op to implement the For loop.

  Returns:
    A list of output `Tensor` objects whose types are T (the same types as
    `inputs`).
1096  """
1097  if rewrite_with_while:
1098    return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)
1099  if body.captured_inputs:
1100    ret = gen_functional_ops._for(
1101        start,
1102        limit,
1103        delta,
1104        inputs + body.captured_inputs,
1105        _LoopBodyCaptureWrapper(body),
1106        name=name)
1107    # Slice off the loop-carried captured inputs.
1108    ret = ret[:-len(body.captured_inputs)]
1109  else:
1110    ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
1111  if hostmem:
1112    num_for_params = 3  # start/limit/delta
1113
1114    input_attr = attr_value_pb2.AttrValue()
1115    input_attr.list.i.extend([num_for_params + i for i in hostmem])
1116    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access
1117
1118    output_attr = attr_value_pb2.AttrValue()
1119    output_attr.list.i.extend(hostmem)
1120    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
1121  return ret
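
# Example (an illustrative sketch; graph mode and the int32 signatures are
# assumptions):
#
#   @function.Defun(dtypes.int32, dtypes.int32)
#   def Accumulate(i, acc):
#     return acc + i
#
#   # Runs acc = Accumulate(i, acc) for i in range(0, 5, 1); the result is [10].
#   result = For(start=0, limit=5, delta=1,
#                inputs=[constant_op.constant(0)], body=Accumulate)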
1122
1123
1124# pylint: enable=invalid-name,protected-access
1125
1126
1127def partitioned_call(args,
1128                     f,
1129                     tout=None,
1130                     executing_eagerly=None,
1131                     config=None,
1132                     executor_type=None):
1133  """Executes a function while respecting device annotations.
1134
1135  Currently, only those functions that execute within the same address space
1136  can be executed.
1137
1138  Args:
1139    args: The arguments of the function, including captured inputs.
1140    f: The function to execute; an instance of `_DefinedFunction` or
1141      `_EagerDefinedFunction`.
    tout: a list containing the output dtype enums; if `None`, inferred from
1143      the signature of `f`.
1144    executing_eagerly: (Optional) A boolean indicating whether the context is
1145      executing eagerly. If `None`, fetched from the global context.
1146    config: (Optional) A `tensorflow::ConfigProto` proto, serialized. If `None`,
1147      all optimizations are disabled. Currently only handled for eager defined
1148      functions.
1149    executor_type: (Optional) A string for the name of the executor to be used
1150      in the function call. If not set, or set to an empty string, the default
1151      tensorflow executor will be used.
1152
1153  Returns:
1154    The list of `Tensor`s returned by invoking `f(args)`. If the function does
1155    not return anything, then returns `None` if eager execution is enabled, or
1156    the `Operation` if not.
1157  """
1158
1159  if tout is None:
1160    tout = tuple(x.type for x in f.definition.signature.output_arg)
1161
1162  if executing_eagerly is None:
1163    executing_eagerly = context.executing_eagerly()
1164
1165  if config is None:
1166    config = function_utils.get_disabled_rewriter_config()
1167
1168  if executor_type is None:
1169    executor_type = ""
1170
1171  if executing_eagerly:
1172    if f.stateful_ops:
1173      outputs = gen_functional_ops.stateful_partitioned_call(
1174          args=args,
1175          Tout=tout,
1176          f=f,
1177          config_proto=config,
1178          executor_type=executor_type)
1179    else:
1180      outputs = gen_functional_ops.partitioned_call(
1181          args=args,
1182          Tout=tout,
1183          f=f,
1184          config_proto=config,
1185          executor_type=executor_type)
1186    return outputs if outputs else None
1187
1188  # The generated binding returns an empty list for functions that don't
1189  # return any Tensors, hence the need to use `create_op` directly.
1190  args = [ops.convert_to_tensor(x) for x in args]
1191  tin_attr = attr_value_pb2.AttrValue(
1192      list=attr_value_pb2.AttrValue.ListValue(
1193          type=[x.dtype.as_datatype_enum for x in args]))
1194  tout_attr = attr_value_pb2.AttrValue(
1195      list=attr_value_pb2.AttrValue.ListValue(type=tout))
1196  func_attr = attr_value_pb2.AttrValue(
1197      func=attr_value_pb2.NameAttrList(name=f.name))
1198  executor_type_attr = attr_value_pb2.AttrValue(
1199      s=compat.as_bytes(executor_type))
1200
1201  # When running in graph mode, the graph and function graphs are optimized
1202  # (i.e. run through grappler) per the session options, so we can disable any
1203  # eager-specific rewriting.
1204  config_proto = attr_value_pb2.AttrValue(s=config)
1205
1206  graph = ops.get_default_graph()
1207  f.add_to_graph(graph)
1208  op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"
1209
1210  # Propagate the attribute indicating the need to compile from function to the
1211  # call itself.
1212  xla_compile_attr = "_XlaMustCompile"
1213  op_attrs = {
1214      "Tin": tin_attr,
1215      "Tout": tout_attr,
1216      "f": func_attr,
1217      "config_proto": config_proto,
1218      "executor_type": executor_type_attr,
1219  }
1220  if xla_compile_attr in f.definition.attr:
1221    op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr]
1222  op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
1223  outputs = op.outputs
1224  if hasattr(f, "graph"):
1225    _set_read_only_resource_inputs_attr(op, f.graph)
1226    if hasattr(f.graph, "collective_manager_ids_used"):
1227      ops.set_int_list_attr(op, acd.COLLECTIVE_MANAGER_IDS,
1228                            f.graph.collective_manager_ids_used)
1229  return outputs if outputs else op
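
# Example (an illustrative sketch; `AddOne` and its signature are invented, and
# graph mode with a `function.Defun`-defined function is assumed):
#
#   @function.Defun(dtypes.float32)
#   def AddOne(x):
#     return x + 1.0
#
#   # Builds a PartitionedCall (or StatefulPartitionedCall) op that runs AddOne.
#   [y] = partitioned_call([constant_op.constant(41.0)], AddOne)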
1230
1231
1232def _set_read_only_resource_inputs_attr(op, func_graph):
1233  """Sets the list of resource inputs which are read-only.
1234
1235  This is used by AutomaticControlDependencies.
1236
1237  Args:
1238    op: PartitionedCall Operation.
1239    func_graph: FuncGraph.
1240  """
1241  read_only_indices = acd.get_read_only_resource_input_indices_graph(func_graph)
1242  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
1243                        read_only_indices)
1244