xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/tensor_array_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorArray: a dynamically sized array of Tensors."""
16# Mixture of pep8 and non-pep8 names, so disable pylint bad-name
17# pylint: disable=g-bad-name
18import contextlib
19
20import traceback
21import weakref
22
23import numpy as np
24
25from tensorflow.python.eager import context
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import errors_impl
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.framework import tensor_spec
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.framework import type_spec
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import control_flow_util
36from tensorflow.python.ops import gen_control_flow_ops
37from tensorflow.python.ops import gen_data_flow_ops
38from tensorflow.python.ops import list_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.platform import tf_logging as logging
41from tensorflow.python.util import tf_should_use
42from tensorflow.python.util.tf_export import tf_export
43
44
45# _GraphTensorArray accesses many of the hidden generated ops, but is in
46# fact built to wrap these methods.
47# pylint: disable=protected-access
class _GraphTensorArray:
  """TensorArray implementation for graph mode, built on the TensorArrayV3 ops.

  Every mutating call (`write`, `scatter`, `split`, `unstack`) emits an op that
  yields a new flow tensor and returns `build_ta_with_new_flow(self, flow)`.
  Reads (`read`, `gather`, `stack`, `concat`) consume the current flow.
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Builds a graph-mode TensorArray, or wraps an existing one.

    Either pass `size` (plus optional creation options) to emit a new
    `TensorArrayV3` op, or pass an existing `handle` together with its `flow`
    to wrap an array that already exists in the graph.

    Args:
      dtype: (required) element data type of the TensorArray.
      size: (optional) int32 scalar `Tensor` giving the number of elements.
        Required when `handle` is not given; forbidden otherwise.
      dynamic_size: (optional) Python bool. When true, writes past the current
        size grow the array. Treated as False when None.
      clear_after_read: (optional) Python bool, treated as True when None. When
        true, values are cleared after being read (no read-many semantics, but
        memory can be released early).
      tensor_array_name: (optional) Python string used when creating the
        TensorArray op. Must not be combined with `handle`.
      handle: (optional) `Tensor` handle of an existing TensorArray. Graph mode
        only; must be paired with `flow`.
      flow: (optional) float scalar `Tensor` taken from an existing
        `TensorArray.flow`. Required whenever `handle` is given.
      infer_shape: (optional, default True) When true, every write refines the
        element shape and all elements must share one shape.
      element_shape: (optional) shape constraint for each element; need not be
        fully defined.
      colocate_with_first_write_call: (optional, default True) When true, the
        array op is created with no device and the write ops (`write`,
        `scatter`/`unstack`, `split`) are colocated with the first tensor
        written. When false, the device scope active at construction applies.
      name: (optional) name for the created ops.

    Raises:
      ValueError: if `handle` is combined with `tensor_array_name`, `size`,
        `element_shape`, `dynamic_size` or `clear_after_read`; if neither
        `handle` nor `size` is given; or if `handle` is given without `flow`.
      TypeError: if `handle` is given but is not a `Tensor`.
    """
    if handle is None:
      if size is None:
        raise ValueError(
            "Argument `size` must be provided if handle is not provided.")
    else:
      # Wrapping an existing array: no creation-time option may accompany it.
      if tensor_array_name:
        raise ValueError(
            "Cannot provide both `handle` and `tensor_array_name` arguments at "
            "the same time.")
      if not isinstance(handle, ops.Tensor):
        raise TypeError(
            f"Expected `handle` to be a Tensor, but got `{handle}` of type "
            f"`{type(handle)}` instead.")
      if size is not None:
        raise ValueError("Cannot provide both a `handle` and `size` arguments "
                         "at the same time.")
      if element_shape is not None:
        raise ValueError(
            "Cannot provide both `handle` and `element_shape` arguments "
            "at the same time.")
      if dynamic_size is not None:
        raise ValueError(
            "Cannot provide both `handle` and `dynamic_size` arguments "
            "at the same time.")
      if clear_after_read is not None:
        raise ValueError(
            "Cannot provide both `handle` and `clear_after_read` arguments "
            "at the same time.")

    clear_after_read = True if clear_after_read is None else clear_after_read
    self._dynamic_size = dynamic_size or False
    self._dtype = dtypes.as_dtype(dtype).base_dtype

    # `_colocate_with` collects the first written tensor when colocation with
    # the first write is requested; it is None otherwise.
    self._colocate_with_first_write_call = colocate_with_first_write_call
    self._colocate_with = [] if colocate_with_first_write_call else None

    # The static element shape is kept in a one-element list so that `grad`
    # can hand the very same mutable slot to the gradient TensorArray. When
    # `infer_shape` is set, `_check_element_shape` refines it on every write.
    self._element_shape = [tensor_shape.as_shape(element_shape)]
    self._infer_shape = infer_shape
    self._size = size

    with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
      if handle is None:

        def _make_tensor_array_op():
          """Emits the TensorArrayV3 op and returns its (handle, flow)."""
          return gen_data_flow_ops.tensor_array_v3(
              dtype=dtype,
              size=size,
              element_shape=element_shape,
              identical_element_shapes=infer_shape,
              dynamic_size=self._dynamic_size,
              clear_after_read=clear_after_read,
              tensor_array_name=tensor_array_name,
              name=scope)

        if not colocate_with_first_write_call:
          self._handle, self._flow = _make_tensor_array_op()
        else:
          # Build the op free of any device/colocation constraint; later write
          # ops run under `_maybe_colocate_with`, which colocates them with the
          # first tensor written into the array.
          with ops.device(None), ops.colocate_with(None, ignore_existing=True):
            self._handle, self._flow = _make_tensor_array_op()
      else:
        self._handle = handle
        if flow is None:
          raise ValueError("flow must not be None if handle is not None.")
        self._flow = flow

  @property
  def flow(self):
    """Flow `Tensor` passed as `flow_in` to the ops this object emits."""
    return self._flow

  @property
  def dtype(self):
    """Base dtype of the elements."""
    return self._dtype

  @property
  def handle(self):
    """`Tensor` handle of the underlying TensorArrayV3 resource."""
    return self._handle

  @property
  def element_shape(self):
    """Current static `TensorShape` of the elements."""
    return self._element_shape[0]

  def _check_element_shape(self, shape):
    """Validates `shape` against the element shape and, if enabled, merges it.

    Args:
      shape: A `TensorShape` to check against (and merge into) the current
        element shape.

    Raises:
      ValueError: if `shape` is incompatible with the current element shape.
    """
    current = self.element_shape
    if not shape.is_compatible_with(current):
      raise ValueError("Inconsistent shapes: saw %s but expected %s " %
                       (shape, current))
    if self._infer_shape:
      self._element_shape[0] = current.merge_with(shape)

  @contextlib.contextmanager
  def _maybe_colocate_with(self, value):
    """Context manager that colocates the ops created inside it.

    If colocation with the first write was requested, the first `value` ever
    passed in is recorded, and every call colocates with that recorded tensor.
    Otherwise the context does nothing.

    Args:
      value: `Tensor` being written; recorded only if nothing is recorded yet.

    Yields:
      None.
    """
    if self._colocate_with_first_write_call:
      if not self._colocate_with:
        self._colocate_with.append(value)
      with ops.colocate_with(self._colocate_with[0]):
        yield
    else:
      yield

  def identity(self):
    """See TensorArray."""
    return build_ta_with_new_flow(self, array_ops.identity(self._flow))

  def grad(self, source, flow=None, name=None):
    """See TensorArray."""
    # tensor_array_grad needs a flow input when the forward array is
    # dynamically sized; depending on it delays creation of the gradient array
    # until the forward array's final size is fixed.
    flow_in = self.flow if flow is None else flow
    with ops.name_scope(name, "TensorArrayGrad", [self._handle]):
      with ops.colocate_with(self._handle):
        grad_handle, _ = gen_data_flow_ops.tensor_array_grad_v3(
            handle=self._handle, source=source, flow_in=flow_in, name=name)
        with ops.control_dependencies([grad_handle]):
          grad_flow = array_ops.identity(flow_in, name="gradient_flow")
        grad_ta = TensorArray(
            dtype=self._dtype,
            handle=grad_handle,
            flow=grad_flow,
            infer_shape=self._infer_shape,
            colocate_with_first_write_call=False)
        # Share the mutable element-shape slot with the gradient array.
        # pylint: disable=protected-access
        grad_ta._implementation._element_shape = self._element_shape
        # pylint: enable=protected-access
        return grad_ta

  def read(self, index, name=None):
    """See TensorArray."""
    element = gen_data_flow_ops.tensor_array_read_v3(
        handle=self._handle,
        index=index,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name)
    # Attach whatever static element shape is known to the result.
    if self._element_shape:
      element.set_shape(self._element_shape[0].dims)
    return element

  def write(self, index, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      tensor = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(tensor, self._dtype)
      self._check_element_shape(tensor.shape)
      with self._maybe_colocate_with(tensor):
        new_flow = gen_data_flow_ops.tensor_array_write_v3(
            handle=self._handle,
            index=index,
            value=tensor,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, new_flow)

  def stack(self, name=None):
    """See TensorArray."""
    with ops.colocate_with(self._handle), ops.name_scope(
        name, "TensorArrayStack", [self._handle]):
      num_elements = self.size()
      stacked = self.gather(math_ops.range(0, num_elements), name=name)
      has_static_size = not self._dynamic_size and self._size is not None
      if self.element_shape and has_static_size:
        stacked.set_shape([tensor_util.constant_value(self._size)] +
                          self.element_shape.dims)
      return stacked

  def gather(self, indices, name=None):
    """See TensorArray."""
    shape_hint = (self._element_shape[0] if self._element_shape
                  else tensor_shape.unknown_shape(None))
    gathered = gen_data_flow_ops.tensor_array_gather_v3(
        handle=self._handle,
        indices=indices,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape=shape_hint)
    if self.element_shape:
      # The leading (number-of-indices) dimension is not known statically.
      gathered.set_shape([None] + self.element_shape.dims)
    return gathered

  def concat(self, name=None):
    """See TensorArray."""
    concatenated, _ = gen_data_flow_ops.tensor_array_concat_v3(
        handle=self._handle,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape_except0=self.element_shape[1:])
    if self.element_shape:
      concatenated.set_shape([None] + self.element_shape.dims[1:])
    return concatenated

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]):
      # Row i of `value` becomes element i.
      leading_dim = array_ops.shape(value)[0]
      indices = math_ops.range(0, leading_dim)
      return self.scatter(indices=indices, value=value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._handle, value, indices]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      tensor = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(tensor, self._dtype)
      if not context.executing_eagerly():
        # Each row of `tensor` is one element, so drop the leading dimension.
        self._check_element_shape(tensor.shape[1:])
      with self._maybe_colocate_with(tensor):
        new_flow = gen_data_flow_ops.tensor_array_scatter_v3(
            handle=self._handle,
            indices=indices,
            value=tensor,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, new_flow)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArraySplit",
                        [self._handle, value, lengths]):
      tensor = ops.convert_to_tensor(value, dtype=self._dtype, name="value")
      with self._maybe_colocate_with(tensor):
        lengths_64 = math_ops.cast(lengths, dtypes.int64)
        if not context.executing_eagerly():
          static_lengths = tensor_util.constant_value(lengths_64)
          # Only when every piece has the same static length is there a single
          # element shape to check: [length] + value.shape[1:].
          if (tensor.shape.dims is not None and static_lengths is not None and
              static_lengths.shape and
              static_lengths.max() == static_lengths.min()):
            self._check_element_shape(
                tensor_shape.TensorShape([static_lengths[0]]).concatenate(
                    tensor.shape[1:]))
        new_flow = gen_data_flow_ops.tensor_array_split_v3(
            handle=self._handle,
            value=tensor,
            lengths=lengths_64,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, new_flow)

  def size(self, name=None):
    """See TensorArray."""
    if self._dynamic_size or self._size is None:
      # The size is only available at run time; query the array op.
      return gen_data_flow_ops.tensor_array_size_v3(
          handle=self._handle, flow_in=self.flow, name=name)
    return ops.convert_to_tensor(self._size, dtype=dtypes.int32)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """See TensorArray."""
    return gen_data_flow_ops.tensor_array_close_v3(
        handle=self._handle, name=name)
388        handle=self._handle, name=name)
389
390
class _GraphTensorArrayV2:
  """Graph-mode implementation of TensorArray backed by TensorLists.

  The backing tensor of this TensorArray is a TensorList variant tensor which is
  stored in the `flow`. The `handle` is always None here. The reason we use the
  `flow` field and not the `handle` field is to ensure backwards compatibility
  with legacy control flow.
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray backed by a TensorList.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if flow is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: (optional) unused. Not supported in TensorLists.
      tensor_array_name: (optional) unused.
      handle: (optional) Must always be None.
      flow: (optional) A variant `Tensor` scalar for a TensorList.
      infer_shape: (optional, default: True) If True, shape inference is
        enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray. Need
        not be fully defined.
      colocate_with_first_write_call: (optional). unused.
      name: (optional) A name for the operation.

    Raises:
      AssertionError: if `handle` is not None. This is checked with `assert`,
        so the check is skipped when Python runs with `-O`.
      TypeError: if `flow` is provided but is not a variant `Tensor`.
      ValueError: if neither `flow` nor `size` is provided, or if `flow` is
        provided together with `size` or `element_shape`.
    """
    assert handle is None
    del handle
    del clear_after_read
    del tensor_array_name
    del colocate_with_first_write_call

    self._dynamic_size = dynamic_size
    self._size = size

    if (flow is not None and
        (not isinstance(flow, ops.Tensor) or flow.dtype != dtypes.variant)):
      # `flow` may not be a Tensor at all (and may have no `dtype`), so fall
      # back to its Python type instead of raising AttributeError while
      # building the message.
      received = getattr(flow, "dtype", type(flow))
      raise TypeError(
          f"Expected `flow` to be a variant tensor, but received `{received}` "
          f"instead.")
    if flow is None and size is None:
      raise ValueError("Argument `size` must be provided if argument `flow` "
                       "is not provided.")
    if flow is not None and size is not None:
      raise ValueError("Cannot provide both `flow` and `size` arguments "
                       "at the same time.")
    if flow is not None and element_shape is not None:
      raise ValueError(
          "Cannot provide both `flow` and `element_shape` arguments "
          "at the same time.")

    self._dtype = dtypes.as_dtype(dtype).base_dtype

    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or the shape of the tensor
    # of the first write. If `infer_shape` is true, all writes checks for
    # shape equality.
    self._element_shape = [tensor_shape.as_shape(element_shape)]
    self._infer_shape = infer_shape
    with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope:
      if flow is None:
        self._flow = list_ops.tensor_list_reserve(
            element_shape=element_shape,
            num_elements=size,
            element_dtype=dtype,
            name=scope)
      else:
        self._flow = flow

    # For backwards compatibility.
    self._colocate_with_first_write_call = None
    self._colocate_with = None

  @property
  def flow(self):
    """The TensorList variant `Tensor` that holds the array's contents."""
    return self._flow

  @property
  def dtype(self):
    """Base dtype of the elements."""
    return self._dtype

  @property
  def element_shape(self):
    """Current static `TensorShape` of the elements."""
    return self._element_shape[0]

  @property
  def handle(self):
    """Always None for TensorList-backed arrays."""
    # We intentionally do not raise an error so that legacy while_loop does not
    # complain.
    return None

  def _check_element_shape(self, shape):
    """Changes the element shape of the array given a shape to merge with.

    Args:
      shape: A `TensorShape` object to merge with.

    Raises:
      ValueError: if the provided shape is incompatible with the current
          element shape of the `TensorArray`.
    """
    if not shape.is_compatible_with(self.element_shape):
      raise ValueError("Inconsistent shapes: saw %s but expected %s " %
                       (shape, self.element_shape))
    if self._infer_shape:
      self._element_shape[0] = self.element_shape.merge_with(shape)

  def identity(self):
    """See TensorArray."""
    flow = array_ops.identity(self._flow)
    return build_ta_with_new_flow(self, flow)

  def grad(self, source, flow=None, name=None):
    """Not supported."""
    raise NotImplementedError()

  def read(self, index, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]):
      value = list_ops.tensor_list_get_item(
          input_handle=self._flow,
          index=index,
          element_dtype=self._dtype,
          element_shape=self.element_shape,
          name=name)
      return value

  def write(self, index, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape)
      # With dynamic_size, writing past the end grows the list.
      flow_out = list_ops.tensor_list_set_item(
          input_handle=self._flow,
          index=index,
          item=value,
          resize_if_index_out_of_bounds=self._dynamic_size,
          name=name)
      return build_ta_with_new_flow(self, flow_out)

  def stack(self, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
      # TODO(b/139941163): remove constant_value after changing num_elements to regular input
      if not self._dynamic_size and self._size is not None:
        ta_size = tensor_util.constant_value(self._size)
      else:
        ta_size = -1
      value = list_ops.tensor_list_stack(
          input_handle=self._flow,
          element_dtype=self._dtype,
          num_elements=ta_size,
          element_shape=self.element_shape)
      return value

  def gather(self, indices, name=None):
    """See TensorArray."""
    value = list_ops.tensor_list_gather(
        input_handle=self._flow,
        indices=indices,
        element_dtype=self._dtype,
        element_shape=self.element_shape,
        name=name)
    return value

  def concat(self, name=None):
    """See TensorArray."""
    # Elements are concatenated along dimension 0, so only the trailing
    # dimensions of the element shape are known statically.
    if self.element_shape:
      element_shape = [None] + self.element_shape.dims[1:]
    else:
      element_shape = None

    value = list_ops.tensor_list_concat(
        input_handle=self._flow,
        element_dtype=self._dtype,
        element_shape=element_shape,
        name=name)
    return value

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape[1:])
      flow_out = list_ops.tensor_list_from_tensor(
          tensor=value, element_shape=value.shape[1:])
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._flow, value, indices]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape[1:])
      flow_out = list_ops.tensor_list_scatter(
          tensor=value,
          indices=indices,
          element_shape=self.element_shape,
          input_handle=self._flow)
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      lengths_64 = math_ops.cast(lengths, dtypes.int64)
      if not context.executing_eagerly():
        clengths = tensor_util.constant_value(lengths_64)
        if value.shape.dims is not None and clengths is not None:
          # Only uniform static lengths give a single element shape to check.
          if clengths.shape and clengths.max() == clengths.min():
            self._check_element_shape(
                tensor_shape.TensorShape([clengths[0]
                                         ]).concatenate(value.shape[1:]))
      flow_out = list_ops.tensor_list_split(
          tensor=value,
          lengths=lengths_64,
          element_shape=self.element_shape,
          name=name)
      return build_ta_with_new_flow(self, flow_out)

  def size(self, name=None):
    """See TensorArray."""
    if not self._dynamic_size and self._size is not None:
      return ops.convert_to_tensor(self._size, dtype=dtypes.int32)
    else:
      return list_ops.tensor_list_length(input_handle=self._flow, name=name)

  def close(self, name=None):
    """See TensorArray."""
    # TensorLists hold no separate resource, so closing is a no-op.
    return gen_control_flow_ops.no_op(name=name)
654    return gen_control_flow_ops.no_op(name=name)
655
656
657# pylint: enable=protected-access
658
659
660class _EagerTensorArray:
661  """Eager-compatible implementation of TensorArray."""
662
663  def __init__(self,
664               dtype,
665               size=None,
666               dynamic_size=None,
667               clear_after_read=None,
668               tensor_array_name=None,
669               handle=None,
670               flow=None,
671               infer_shape=True,
672               element_shape=None,
673               colocate_with_first_write_call=True,
674               name=None):
675    """Constructs a TensorArray compatible with eager execution.
676
677    Args:
678      dtype: (required) data type of the TensorArray.
679      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
680        Required if handle is not provided.
681      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
682        can grow the TensorArray past its initial size.  Default: False.
683      clear_after_read: Boolean (optional, default: True).  If True, clear
684        TensorArray values after reading them.  This disables read-many
685        semantics, but allows early release of memory.
686      tensor_array_name: unused.
687      handle: unsupported.
688      flow: unsupported.
689      infer_shape: used for error checking, same semantics as TensorArray.
690      element_shape: used for error checking, same semantics as TensorArray.
691      colocate_with_first_write_call: unsupported.
692      name: unsupported.
693
694    Raises:
695      ValueError: handle or flow are supplied, or if size is not supplied.
696    """
697
698    del (flow, tensor_array_name, name)  # Unused.
699
700    if handle is not None:
701      raise ValueError("TensorArray handles are not supported when eager "
702                       "execution is enabled.")
703    if size is None:
704      raise ValueError("Size must be declared for TensorArrays when eager "
705                       "execution is enabled.")
706
707    # These attributes are not meaningful when eager is enabled, but some
708    # library functions (e.g., those in control_flow_ops.py) access them to
709    # create new tensor arrays; as such, we define them for the sake of
710    # compatibility.
711    self._handle = None
712    # we assign a dummy value to _flow in case other code assumes it to be
713    # a Tensor
714    self._flow = constant_op.constant(0, dtype=dtypes.int32)
715    self._infer_shape = infer_shape
716    self._element_shape = tensor_shape.as_shape(element_shape)
717    self._colocate_with_first_write_call = colocate_with_first_write_call
718
719    self._dtype = dtypes.as_dtype(dtype).base_dtype
720    self._dynamic_size = dynamic_size or False
721    self._clear_after_read = (True
722                              if clear_after_read is None else clear_after_read)
723    self._previously_read_indices = []
724
725    if isinstance(size, ops.EagerTensor):
726      size = size.numpy()
727    self._tensor_array = [None for _ in range(size)]
728
729  @property
730  def flow(self):
731    """For compatibility; flows are not meaningful when eager is enabled."""
732    return self._flow
733
734  @property
735  def dtype(self):
736    return self._dtype
737
738  @property
739  def handle(self):
740    """For compatibility; handles are not meaningful when eager is enabled."""
741    return self._handle
742
743  @property
744  def element_shape(self):
745    return self._element_shape
746
747  def identity(self):
748    """See TensorArray."""
749    return self.parent()
750
751  def grad(self, source, flow=None, name=None):
752    raise NotImplementedError(
753        "TensorArray.grad is not supported when executing eagerly; eager's "
754        "gradient implementation does not use/need this function to compute "
755        "gradients of operations that use TensorArrays.")
756
757  def read(self, index, name=None):
758    """See TensorArray."""
759    del name  # not meaningful when executing eagerly.
760
761    if isinstance(index, ops.EagerTensor):
762      index = index.numpy()
763
764    if index < 0:
765      raise errors_impl.OutOfRangeError(
766          None, None,
767          "Reading from negative indices (index %d) is not allowed." % index)
768
769    if index >= len(self._tensor_array):
770      raise errors_impl.OutOfRangeError(
771          None, None, "Tried to read from index %d but array size is: %d " %
772          (index, len(self._tensor_array)))
773
774    tensor = self._tensor_array[index]
775    if tensor is None:
776      if index in self._previously_read_indices:
777        raise errors_impl.InvalidArgumentError(
778            None, None,
779            "Could not read index %d twice because it was cleared after "
780            "a previous read (perhaps try setting clear_after_read = false?)" %
781            index)
782      else:
783        tensor = self._maybe_zero(index)
784
785    if self._clear_after_read:
786      self._tensor_array[index] = None
787      self._previously_read_indices.append(index)
788    return tensor
789
790  def _write(self, index, value):
791    """Writes `value` into index named by `index`.
792
793    Args:
794      index: 0-D.  int32 scalar with the index to write to.
795      value: N-D.  Tensor of type `dtype`.  The `Tensor` to write to `index`.
796
797    Raises:
798      errors_impl.InvalidArgumentError: `value` dtype does not match dtype.
799      errors_impl.OutOfRangeError: `index` is out of bounds.
800      ValueError: shape of `value` is not consistent with inferred shape.
801    """
802
803    if isinstance(index, ops.EagerTensor):
804      index = index.numpy()
805
806    if index < 0:
807      raise errors_impl.OutOfRangeError(
808          None, None,
809          "Writing to negative indices (index %d) is not allowed." % index)
810
811    size = len(self._tensor_array)
812    if index >= size:
813      if not self._dynamic_size:
814        raise errors_impl.OutOfRangeError(
815            None, None,
816            "Tried to write to index %d but array is not resizeable and size "
817            "is: %d " % (index, size))
818      self._tensor_array.extend(None for _ in range(index - size + 1))
819
820    if not isinstance(value, ops.EagerTensor):
821      # TODO(b/129870929): Fix after all callers provide proper init dtype.
822      value = ops.convert_to_tensor(
823          value, preferred_dtype=self._dtype, name="value")
824
825    if self._dtype != value.dtype:
826      raise errors_impl.InvalidArgumentError(
827          None, None,
828          "TensorArray dtype is %s but Op is trying to write dtype %s " %
829          (self._dtype.name, value.dtype.name))
830
831    if not self._element_shape.is_compatible_with(value.shape):
832      raise ValueError("Incompatible shape for value (%s), expected (%s)" %
833                       (value.shape, self._element_shape))
834
835    if self._infer_shape:
836      self._element_shape = self._element_shape.merge_with(value.shape)
837
838    self._tensor_array[index] = value
839
840  def write(self, index, value, name=None):
841    """See TensorArray."""
842    del name  # not meaningful when executing eagerly.
843    self._write(index, value)
844    return self.parent()
845
846  def _maybe_zero(self, ix):
847    val = self._tensor_array[ix]
848    if val is None:
849      val = self._tensor_array[ix] = array_ops.zeros(
850          shape=self._element_shape, dtype=self._dtype)
851    return val
852
853  def stack(self, name=None):
854    """See TensorArray."""
855    if self._tensor_array:
856      for ix in range(len(self._tensor_array)):
857        self._maybe_zero(ix)
858    if not self._tensor_array and self._element_shape.is_fully_defined():
859      return ops.convert_to_tensor(
860          np.ndarray([0] + self._element_shape), name=name, dtype=self._dtype)
861    else:
862      return ops.convert_to_tensor(
863          self._tensor_array, name=name, dtype=self._dtype)
864
865  def gather(self, indices, name=None):
866    """See TensorArray."""
867    del name  # not meaningful when executing eagerly.
868    if isinstance(indices, ops.EagerTensor):
869      indices = indices.numpy()
870    return array_ops.stack([self._maybe_zero(i) for i in indices])
871
872  def concat(self, name=None):
873    """See TensorArray."""
874    try:
875      return array_ops.concat(
876          [self._maybe_zero(ix) for ix in range(len(self._tensor_array))],
877          0,
878          name=name)
879    except errors_impl.OpError:
880      # Reproduce a subset of the error-handling for graph-mode TensorArrays.
881      shapes = [t.shape for t in self._tensor_array]
882      ndims = [s.ndims for s in shapes]
883      if 0 in ndims:
884        idx = ndims.index(0)
885        raise errors_impl.InvalidArgumentError(
886            None, None, "Concat saw a scalar shape at index %d but requires "
887            "at least vectors." % idx)
888      else:
889        raise
890
891  def unstack(self, value, name=None):
892    """See TensorArray."""
893    tensors = array_ops.unstack(value, name=name)
894    if len(tensors) > len(self._tensor_array) and not self._dynamic_size:
895      raise ValueError(
896          "Cannot unstack %d tensors into a TensorArray of static size %d " %
897          (len(tensors), len(self._tensor_array)))
898    self._tensor_array = tensors
899    return self.parent()
900
901  def scatter(self, indices, value, name=None):
902    """See TensorArray."""
903    del name  # not meaningful when executing eagerly.
904    if isinstance(indices, ops.EagerTensor):
905      indices = indices.numpy()
906    for index, val in zip(indices, array_ops.unstack(value)):
907      self._write(index, val)  # pylint: disable=protected-access
908    return self.parent()
909
910  def split(self, value, lengths, name=None):
911    """See TensorArray."""
912    # TODO(b/129870929): Fix after all callers provide proper init dtype.
913    value = ops.convert_to_tensor(
914        value, preferred_dtype=self._dtype, name="value")
915    _check_dtypes(value, self._dtype)
916    lengths = ops.convert_to_tensor(lengths)
917    sum_lengths = math_ops.reduce_sum(lengths)
918    if lengths.shape.ndims != 1:
919      raise errors_impl.InvalidArgumentError(
920          None, None, "Expected lengths to be a vector, received shape: %s " %
921          lengths.shape.as_list())
922    elif value.shape.ndims == 0:
923      raise errors_impl.InvalidArgumentError(
924          None, None, "Expected value to be at least a vector, "
925          "but received shape: %s " % value.shape.as_list())
926    elif sum_lengths.numpy() != value.shape.as_list()[0]:
927      raise errors_impl.InvalidArgumentError(
928          None, None, "Expected sum of lengths to be equal to "
929          "values.shape[0], but sum of lengths is %d and "
930          "value's shape is: %s " % (sum_lengths.numpy(),
931                                     value.shape.as_list()))
932    elif not self._dynamic_size and lengths.shape[0] != len(self._tensor_array):
933      raise errors_impl.InvalidArgumentError(
934          None, None, "TensorArray's size is not equal to the size of "
935          "lengths (%d vs. %d), and the TensorArray is not marked as "
936          "dynamically resizeable." %
937          (len(self._tensor_array), lengths.shape[0]))
938    else:
939      self._tensor_array = array_ops.split(value, lengths, name=name)
940      return self.parent()
941
942  def size(self, name=None):
943    """See TensorArray."""
944    del name  # not meaningful when executing eagerly.
945    return constant_op.constant(len(self._tensor_array))
946
947  def close(self, name=None):
948    del name  # not meaningful when executing eagerly.
949    del self._tensor_array[:]
950
951
952# TensorArray is designed to hide an underlying implementation object
953# and as such accesses many of that object's hidden fields.
954# pylint: disable=protected-access
955# pylint:disable=line-too-long
@tf_export("TensorArray")
class TensorArray:
  """Class wrapping dynamic-sized, per-time-step, Tensor arrays.

  This class is meant to be used with dynamic iteration primitives such as
  `while_loop` and `map_fn`.  It supports gradient back-propagation via special
  "flow" control flow dependencies.

  Note that although the array can be read multiple times and positions can be
  overwritten, behavior may be undefined when storing multiple references to
  the same array and clear_after_read is False. In particular, avoid using
  methods like concat() to convert an intermediate TensorArray to a Tensor,
  then further modifying the TensorArray, particularly if you need to backprop
  through it later.

  Example 1: Plain reading and writing.

  >>> ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=False)
  >>> ta = ta.write(0, 10)
  >>> ta = ta.write(1, 20)
  >>> ta = ta.write(2, 30)
  >>>
  >>> ta.read(0)
  <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
  >>> ta.read(1)
  <tf.Tensor: shape=(), dtype=float32, numpy=20.0>
  >>> ta.read(2)
  <tf.Tensor: shape=(), dtype=float32, numpy=30.0>
  >>> ta.stack()
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([10., 20., 30.],
  dtype=float32)>

  Example 2: Fibonacci sequence algorithm that writes in a loop then returns.

  >>> @tf.function
  ... def fibonacci(n):
  ...   ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  ...   ta = ta.unstack([0., 1.])
  ...
  ...   for i in range(2, n):
  ...     ta = ta.write(i, ta.read(i - 1) + ta.read(i - 2))
  ...
  ...   return ta.stack()
  >>>
  >>> fibonacci(7)
  <tf.Tensor: shape=(7,), dtype=float32,
  numpy=array([0., 1., 1., 2., 3., 5., 8.], dtype=float32)>

  Example 3: A simple loop interacting with a `tf.Variable`.

  >>> v = tf.Variable(1)
  >>> @tf.function
  ... def f(x):
  ...   ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  ...   for i in tf.range(x):
  ...     v.assign_add(i)
  ...     ta = ta.write(i, v)
  ...   return ta.stack()
  >>> f(5)
  <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 1,  2,  4,  7, 11],
  dtype=int32)>
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Construct a new TensorArray or wrap an existing TensorArray handle.

    A note about the parameter `name`:

    The name of the `TensorArray` (even if passed in) is uniquified: each time
    a new `TensorArray` is created at runtime it is assigned its own name for
    the duration of the run.  This avoids name collisions if a `TensorArray`
    is created within a `while_loop`.

    The backing implementation is chosen as follows:

    * eager execution, with no `flow` or a non-variant `flow`: the eager
      implementation;
    * a variant `flow`, or control flow v2 enabled in the default graph: the
      TF2 graph (TensorList-based) implementation;
    * otherwise: the TF1 graph implementation.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference is
        enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray. Need
        not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`).  If `False`,
        the TensorArray will be placed on the device determined by the device
        context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    if (context.executing_eagerly() and
        (flow is None or flow.dtype != dtypes.variant)):
      # It is possible to create a Variant-style TensorArray even in eager mode,
      # and this is fine but can have performance implications in eager.
      # An example of when this happens is if a tf.function returns a
      # TensorArray in its output; its flow variant object is returned to Eager.
      # This can be wrapped back up in a Variant-style TensorArray.
      implementation = _EagerTensorArray
    elif (flow is not None and flow.dtype == dtypes.variant or
          control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
      implementation = _GraphTensorArrayV2
    else:
      implementation = _GraphTensorArray
    self._implementation = implementation(
        dtype,
        size=size,
        dynamic_size=dynamic_size,
        clear_after_read=clear_after_read,
        tensor_array_name=tensor_array_name,
        handle=handle,
        flow=flow,
        infer_shape=infer_shape,
        element_shape=element_shape,
        colocate_with_first_write_call=colocate_with_first_write_call,
        name=name)

    # A weak reference lets the implementation hand back this wrapper (e.g.
    # from `write`) without creating a reference cycle.
    self._implementation.parent = weakref.ref(self)

  @property
  def flow(self):
    """The flow `Tensor` forcing ops leading to this TensorArray state."""
    return self._implementation._flow

  @property
  def dtype(self):
    """The data type of this TensorArray."""
    return self._implementation._dtype

  @property
  def handle(self):
    """The reference to the TensorArray."""
    return self._implementation.handle

  @property
  def element_shape(self):
    """The `tf.TensorShape` of elements in this TensorArray."""
    return self._implementation.element_shape

  @property
  def dynamic_size(self):
    """Python bool; if `True` the TensorArray can grow dynamically."""
    return self._implementation._dynamic_size

  @property
  def _infer_shape(self):
    # TODO(slebedev): consider making public or changing TensorArrayStructure
    # to access _implementation directly. Note that dynamic_size is also
    # only used by TensorArrayStructure.
    return self._implementation._infer_shape

  def identity(self):
    """Returns a TensorArray with the same content and properties.

    Returns:
      A new TensorArray object with flow that ensures the control dependencies
      from the contexts will become control dependencies for writes, reads, etc.
      Use this object for all subsequent operations.
    """
    return self._implementation.identity()

  def grad(self, source, flow=None, name=None):
    """Forwards to the underlying implementation's `grad`.

    When executing eagerly the implementation always raises
    `NotImplementedError`, because eager gradients do not use this hook.

    Args:
      source: Forwarded unchanged to the implementation.
      flow: (optional) Forwarded unchanged to the implementation.
      name: A name for the operation (optional).

    Returns:
      Whatever the underlying implementation's `grad` returns.
    """
    return self._implementation.grad(source, flow=flow, name=name)

  def read(self, index, name=None):
    """Read the value at location `index` in the TensorArray.

    Args:
      index: 0-D.  int32 tensor with the index to read from.
      name: A name for the operation (optional).

    Returns:
      The tensor at index `index`.
    """
    return self._implementation.read(index, name=name)

  @tf_should_use.should_use_result(warn_in_eager=True)
  def write(self, index, value, name=None):
    """Write `value` into index `index` of the TensorArray.

    Args:
      index: 0-D.  int32 scalar with the index to write to.
      value: N-D.  Tensor of type `dtype`.  The Tensor to write to this index.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the write occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if there are more writers than specified.
    """
    return self._implementation.write(index, value, name=name)

  def stack(self, name=None):
    """Return the values in the TensorArray as a stacked `Tensor`.

    All of the values must have been written and their shapes must all match.
    If input shapes have rank-`R`, then output shape will have rank-`(R+1)`.

    For example:


    >>> ta = tf.TensorArray(tf.int32, size=3)
    >>> ta = ta.write(0, tf.constant([1, 2]))
    >>> ta = ta.write(1, tf.constant([3, 4]))
    >>> ta = ta.write(2, tf.constant([5, 6]))
    >>> ta.stack()
    <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
    array([[1, 2],
           [3, 4],
           [5, 6]], dtype=int32)>


    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray stacked into one tensor.
    """
    return self._implementation.stack(name=name)

  def gather(self, indices, name=None):
    """Return selected values in the TensorArray as a packed `Tensor`.

    All of selected values must have been written and their shapes
    must all match.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If the
        `TensorArray` is not dynamic, `max_value=size()`.
      name: A name for the operation (optional).

    Returns:
      The tensors in the `TensorArray` selected by `indices`, packed into one
      tensor.
    """
    return self._implementation.gather(indices, name=name)

  def concat(self, name=None):
    """Return the values in the TensorArray as a concatenated `Tensor`.

    All of the values must have been written, their ranks must match, and
    their shapes must all match for all dimensions except the first.

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray concatenated into one tensor.
    """
    return self._implementation.concat(name=name)

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """Unstack the values of a `Tensor` in the TensorArray.

    If input value shapes have rank-`R`, then the output TensorArray will
    contain elements whose shapes are rank-`(R-1)`.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unstack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the unstack occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.unstack(value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """Scatter the values of a `Tensor` in specific indices of a `TensorArray`.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If the
        `TensorArray` is not dynamic, `max_value=size()`.
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unpack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the scatter occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.scatter(indices, value, name=name)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """Split the values of a `Tensor` into the TensorArray.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to split.
      lengths: 1-D.  int32 vector with the lengths to use when splitting `value`
        along its first dimension.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the split occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.split(value, lengths, name=name)

  def size(self, name=None):
    """Return the size of the TensorArray.

    Args:
      name: A name for the operation (optional).

    Returns:
      The size of the TensorArray, as produced by the underlying
      implementation. When executing eagerly this is a scalar constant tensor
      holding the current number of elements.
    """
    return self._implementation.size(name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """Close the current TensorArray.

    When executing eagerly this discards all stored elements.

    Args:
      name: A name for the operation (optional).

    Returns:
      Whatever the underlying implementation's `close` returns (`None` when
      executing eagerly).
    """
    return self._implementation.close(name=name)
1300
1301
def build_ta_with_new_flow(old_ta, flow):
  """Builds a TensorArray with a new `flow` tensor.

  The new array is constructed with `old_ta`'s dtype, handle, `infer_shape`
  and `colocate_with_first_write_call`. Its implementation then receives
  `old_ta`'s `_dynamic_size`, `_size`, `_colocate_with` and `_element_shape`.
  `_element_shape` is assigned by reference, so the two arrays share it.

  Args:
    old_ta: A `TensorArray`, or directly one of its implementation objects.
    flow: The flow tensor for the new array.

  Returns:
    A new `TensorArray`.

  Raises:
    NotImplementedError: in graph mode with control flow v2 enabled, when the
      implementation of `old_ta` is not a `_GraphTensorArrayV2`.
  """
  # Callers pass either the public wrapper or its implementation object.
  impl = old_ta._implementation if isinstance(old_ta, TensorArray) else old_ta

  needs_v2_impl = (
      not context.executing_eagerly() and
      not isinstance(impl, _GraphTensorArrayV2) and
      control_flow_util.EnableControlFlowV2(ops.get_default_graph()))
  if needs_v2_impl:
    raise NotImplementedError("Attempting to build a graph-mode TF2-style "
                              "TensorArray from either an eager-mode "
                              "TensorArray or a TF1-style TensorArray.  "
                              "This is not currently supported.  You may be "
                              "attempting to capture a TensorArray "
                              "inside a tf.function or tf.data map function. "
                              "Instead, construct a new TensorArray inside "
                              "the function.")

  new_ta = TensorArray(
      dtype=impl.dtype,
      handle=impl.handle,
      flow=flow,
      infer_shape=impl._infer_shape,
      colocate_with_first_write_call=impl._colocate_with_first_write_call)
  new_impl = new_ta._implementation
  # Copied in this fixed order; `_element_shape` ends up shared, not copied.
  for attr_name in ("_dynamic_size", "_size", "_colocate_with",
                    "_element_shape"):
    setattr(new_impl, attr_name, getattr(impl, attr_name))
  return new_ta
1331
1332
1333# pylint: enable=protected-access
1334
1335
def _check_dtypes(value, dtype):
  """Logs (but does not raise) when `value.dtype` differs from `dtype`.

  The logged message includes the current Python stack to help locate the
  offending caller.
  """
  if value.dtype == dtype:
    return
  caller_stack = "".join(traceback.format_stack())
  logging.error("Error: Input value {} has dtype {}, but expected dtype {}.  "
                "This leads to undefined behavior and will be an error "
                "in future versions of TensorFlow.  Traceback:\n{}".format(
                    value, str(value.dtype), str(dtype), caller_stack))
1343
1344
@tf_export("TensorArraySpec")
@type_spec.register("tf.TensorArraySpec")
class TensorArraySpec(type_spec.TypeSpec):
  """Type specification for a `tf.TensorArray`."""

  __slots__ = ["_element_shape", "_dtype", "_dynamic_size", "_infer_shape"]

  value_type = property(lambda self: TensorArray)

  def __init__(self,
               element_shape=None,
               dtype=dtypes.float32,
               dynamic_size=False,
               infer_shape=True):
    """Constructs a type specification for a `tf.TensorArray`.

    Args:
      element_shape: The shape of each element in the `TensorArray`.
      dtype: Data type of the `TensorArray`.
      dynamic_size: Whether the `TensorArray` can grow past its initial size.
      infer_shape: Whether shape inference is enabled.
    """
    self._element_shape = tensor_shape.as_shape(element_shape)
    self._dtype = dtypes.as_dtype(dtype)
    self._dynamic_size = dynamic_size
    self._infer_shape = infer_shape

  def is_subtype_of(self, other):
    """Returns True if `self` is a subtype of `other`.

    Only the dtype and `dynamic_size` are compared; the element shape and
    `infer_shape` do not take part in this check.
    """
    # pylint: disable=protected-access
    return (isinstance(other, TensorArraySpec) and
            self._dtype == other._dtype and
            self._dynamic_size == other._dynamic_size)

  def most_specific_common_supertype(self, others):
    """Returns the most specific supertype of `self` and `others`.

    Args:
      others: A Sequence of `TypeSpec`.

    Returns:
      A `TensorArraySpec` whose element shape is the most specific common
      supertype of all element shapes, and whose `infer_shape` is True only if
      it is True for every spec. Returns `None` if a supertype does not exist:
      some element of `others` is not a `TensorArraySpec`, the element shapes
      have no common supertype, or the dtypes or `dynamic_size` values differ.
    """
    # pylint: disable=protected-access
    # `others` is iterated several times below. Materialize it so a one-shot
    # iterable is not silently exhausted by the first pass.
    others = tuple(others)
    if not all(isinstance(other, TensorArraySpec) for other in others):
      # `None` (not `False`) is the documented "no common supertype" result.
      return None

    common_shape = self._element_shape.most_specific_common_supertype(
        other._element_shape for other in others)
    if common_shape is None:
      return None

    if not all(self._dtype == other._dtype for other in others):
      return None

    if not all(self._dynamic_size == other._dynamic_size for other in others):
      return None

    infer_shape = self._infer_shape and all(
        other._infer_shape for other in others)

    return TensorArraySpec(common_shape, self._dtype, self._dynamic_size,
                           infer_shape)

  def is_compatible_with(self, other):
    """Returns True if `other` is a compatible `TensorArraySpec`.

    `other` may be a spec or a value. Values are first converted with
    `type_spec.type_spec_from_value`. Dtypes and element shapes must be
    compatible and `dynamic_size` must match.
    """
    # pylint: disable=protected-access
    if not isinstance(other, type_spec.TypeSpec):
      other = type_spec.type_spec_from_value(other)

    # Note: we intentionally exclude infer_shape in this check.
    return (isinstance(other, TensorArraySpec) and
            self._dtype.is_compatible_with(other._dtype) and
            self._element_shape.is_compatible_with(other._element_shape) and
            self._dynamic_size == other._dynamic_size)

  def _serialize(self):
    """Returns the constructor arguments, in `__init__` order."""
    return (self._element_shape, self._dtype, self._dynamic_size,
            self._infer_shape)

  @property
  def _component_specs(self):
    """A single scalar variant tensor: the TF2-style tensor-list flow."""
    return [tensor_spec.TensorSpec([], dtypes.variant)]

  def _to_components(self, value):
    """Returns `[flow]`, where `flow` is a variant tensor representing `value`.

    A TensorArray whose flow is already a variant is returned as-is. Any other
    TensorArray is stacked and converted into a tensor list.

    Raises:
      TypeError: if `value` is not a `TensorArray`.
    """
    if not isinstance(value, TensorArray):
      raise TypeError("Expected value to be a TensorArray, but got: `{}`".format(
          type(value)))
    if value.flow is not None and value.flow.dtype == dtypes.variant:
      return [value.flow]
    else:
      # Convert to a TF2-style TensorArray.
      # TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or
      # "implementation / as_variant" arg to TensorArray constructor.
      with ops.name_scope("convert_tensor_array"):
        flow = list_ops.tensor_list_from_tensor(
            tensor=value.stack(), element_shape=value.element_shape)
      return [flow]

  def _from_components(self, tensor_list):
    """Rebuilds a `TensorArray` around the variant flow `tensor_list[0]`."""
    # This will return a TF2 Graph-style TensorArray because tensor_list[0] is
    # a variant object.  size == -1 implies unknown size.
    ret = TensorArray(
        dtype=self._dtype,
        flow=tensor_list[0],
        dynamic_size=self._dynamic_size,
        infer_shape=self._infer_shape)
    # The one-element list mirrors how the graph implementations hold their
    # element shape. NOTE(review): assumes the variant-flow implementation
    # expects a list here; confirm against _GraphTensorArrayV2.
    ret._implementation._element_shape = [self._element_shape]  # pylint: disable=protected-access
    return ret

  @staticmethod
  def from_value(value):
    """Builds a `TensorArraySpec` describing the `TensorArray` `value`.

    Raises:
      TypeError: if `value` is not a `TensorArray`.
    """
    if not isinstance(value, TensorArray):
      raise TypeError("Expected value to be a TensorArray, but got: `{}`".format(
          type(value)))

    return TensorArraySpec(
        dtype=value.dtype,
        element_shape=value.element_shape,
        dynamic_size=value.dynamic_size,
        infer_shape=value._infer_shape)  # pylint: disable=protected-access

  def _to_legacy_output_types(self):
    """Returns the element dtype."""
    return self._dtype

  def _to_legacy_output_shapes(self):
    """Returns `[dynamic_size, infer_shape] + element_shape` as a shape."""
    # Sneak the dynamic_size and infer_shape values into the legacy shape.
    return (tensor_shape.TensorShape([self._dynamic_size, self._infer_shape
                                     ]).concatenate(self._element_shape))

  def _to_legacy_output_classes(self):
    """Returns the `TensorArray` class."""
    return TensorArray
1474
1475
# Register the TypeSpec for TensorArray: values of `TensorArray` (and, via
# `allow_subclass=True`, its subclasses) are mapped to a spec through
# `TensorArraySpec.from_value`. If TensorArray is updated to be a
# CompositeTensor, then this registration can be deleted.
type_spec.register_type_spec_from_value_converter(
    TensorArray,
    TensorArraySpec.from_value,
    allow_subclass=True,
)
1480