1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.tensor_array_ops."""
16
17import numpy as np
18
19from tensorflow.core.protobuf import config_pb2
20from tensorflow.python.client import session as session_lib
21from tensorflow.python.data.ops import dataset_ops
22from tensorflow.python.eager import backprop
23from tensorflow.python.eager import context
24from tensorflow.python.eager import def_function
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import tensor_spec
31from tensorflow.python.framework import test_util
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import control_flow_ops
34from tensorflow.python.ops import control_flow_util
35from tensorflow.python.ops import data_flow_ops
36from tensorflow.python.ops import gen_data_flow_ops
37from tensorflow.python.ops import gradients_impl
38from tensorflow.python.ops import init_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.ops import tensor_array_grad
41from tensorflow.python.ops import tensor_array_ops
42from tensorflow.python.ops import variable_scope
43from tensorflow.python.ops import variables
44import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
45from tensorflow.python.platform import test
46
47
48def _make_converter(tf_dtype):
49  def _converter(x):
50    if tf_dtype == dtypes.string:
51      # In Python3, np.str_ is unicode, while we always want bytes
52      return np.asarray(x).astype("|S")
53    x = np.asarray(x).astype(tf_dtype.as_numpy_dtype)
54    if tf_dtype.is_complex:
55      # Add a non-zero imaginary component to x.
56      x -= 1j * x
57    return x
58  return _converter
59
60
61def _make_ta(size, name, dtype=dtypes.float32, infer_shape=False):
62  return tensor_array_ops.TensorArray(
63      dtype=dtype, tensor_array_name=name, size=size, infer_shape=infer_shape)
64
65
66@test_util.run_all_in_graph_and_eager_modes
67@test_util.with_control_flow_v2
68class TensorArrayTest(test.TestCase):
69
70  @classmethod
71  def setUpClass(cls):
72    super(TensorArrayTest, cls).setUpClass()
73    cls._workers, _ = test.create_local_cluster(num_workers=3, num_ps=0)
74
75  @classmethod
76  def tearDownClass(cls):
77    super(TensorArrayTest, cls).tearDownClass()
78    session_lib.Session.reset(cls._workers[0].target)
79
80  @test_util.run_in_graph_and_eager_modes
81  def testTensorArrayWriteRead(self):
82    with self.session():
83      ta = tensor_array_ops.TensorArray(
84          dtype=dtypes.float32,
85          tensor_array_name="foo",
86          size=3,
87          infer_shape=False)
88
89      w0 = ta.write(0, [[4.0, 5.0]])
90      w1 = w0.write(1, [[1.0]])
91      w2 = w1.write(2, -3.0)
92
93      r0 = w2.read(0)
94      r1 = w2.read(1)
95      r2 = w2.read(2)
96
97      d0, d1, d2 = self.evaluate([r0, r1, r2])
98      self.assertAllEqual([[4.0, 5.0]], d0)
99      self.assertAllEqual([[1.0]], d1)
100      self.assertAllEqual(-3.0, d2)
101
102  def _testTensorArrayWritePack(self, tf_dtype):
103    with self.cached_session():
104      ta = tensor_array_ops.TensorArray(
105          dtype=tf_dtype, tensor_array_name="foo", size=3)
106
107      convert = _make_converter(tf_dtype)
108
109      w0 = ta.write(0, convert([[4.0, 5.0]]))
110      w1 = w0.write(1, convert([[6.0, 7.0]]))
111      w2 = w1.write(2, convert([[8.0, 9.0]]))
112
113      c0 = w2.stack()
114
115      c0 = self.evaluate(c0)
116      self.assertAllEqual(
117          convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), c0)
118
119  def _testTensorArrayWritePackMaybeLegacy(self):
120    self._testTensorArrayWritePack(dtypes.float32)
121    self._testTensorArrayWritePack(dtypes.float64)
122    self._testTensorArrayWritePack(dtypes.int32)
123    self._testTensorArrayWritePack(dtypes.int64)
124    self._testTensorArrayWritePack(dtypes.complex64)
125    self._testTensorArrayWritePack(dtypes.complex128)
126    self._testTensorArrayWritePack(dtypes.string)
127
128  def testTensorArrayWritePack(self):
129    self._testTensorArrayWritePackMaybeLegacy()
130
131  def testEmptyTensorArrayPack(self):
132    with self.session():
133      ta = tensor_array_ops.TensorArray(
134          dtype=dtypes.float32, tensor_array_name="foo", size=3)
135
136      empty_element = np.zeros((0, 1), dtype=np.float32)
137      w0 = ta.write(0, empty_element)
138      w1 = w0.write(1, empty_element)
139      w2 = w1.write(2, empty_element)
140
141      c0 = w2.stack()
142
143      c0 = self.evaluate(c0)
144      self.assertAllEqual([3, 0, 1], c0.shape)
145
146  def testTensorArrayWriteConcatInParallel(self):
147    with self.session():
148
149      def _concat_1():
150        ta = tensor_array_ops.TensorArray(
151            dtype=dtypes.int32, size=2, infer_shape=False)
152        w0 = ta.write(0, constant_op.constant([1]))
153        w1 = w0.write(1, constant_op.constant([],
154                                              shape=(0,),
155                                              dtype=dtypes.int32))
156        return w1.concat()
157
158      def _concat_2():
159        ta = tensor_array_ops.TensorArray(
160            dtype=dtypes.int32, size=3, infer_shape=False)
161        w0 = ta.write(0, constant_op.constant([8]))
162        w1 = w0.write(1, constant_op.constant([],
163                                              shape=(0,),
164                                              dtype=dtypes.int32))
165        w2 = w1.write(2, constant_op.constant([9]))
166        return w2.concat()
167
168      def _write(index, output):
169        elements = control_flow_ops.cond(
170            math_ops.less(index, 3), _concat_1, _concat_2)
171        return (index + 1, output.write(index, elements))
172
173      num_iterations = 6
174      init_state = (0,
175                    tensor_array_ops.TensorArray(
176                        dtype=dtypes.int32,
177                        size=num_iterations,
178                        infer_shape=False))
179      _, final_state = control_flow_ops.while_loop(
180          lambda i, _: i < num_iterations, _write, init_state)
181
182      c0 = final_state.concat()
183
184      c0 = self.evaluate(c0)
185      self.assertAllEqual([1, 1, 1, 8, 9, 8, 9, 8, 9], c0)
186
187  def _testTensorArrayWriteConcat(self, tf_dtype):
188    with self.cached_session():
189      ta = tensor_array_ops.TensorArray(
190          dtype=tf_dtype, tensor_array_name="foo", size=3, infer_shape=False)
191
192      convert = _make_converter(tf_dtype)
193
194      w0 = ta.write(0, convert([[4.0, 5.0], [104.0, 105.0], [204.0, 205.0]]))
195      w1 = w0.write(1, convert([[6.0, 7.0], [106.0, 107.0]]))
196      w2 = w1.write(2, convert([[8.0, 9.0]]))
197
198      c0 = w2.concat()
199
200      c0 = self.evaluate(c0)
201      self.assertAllEqual(
202          convert([[4.0, 5.0], [104.0, 105.0], [204.0, 205.0], [6.0, 7.0],
203                   [106.0, 107.0], [8.0, 9.0]]), c0)
204
205  @test_util.deprecated_graph_mode_only
206  def testTensorArrayWriteConcat(self):
207    self._testTensorArrayWriteConcat(dtypes.float32)
208    self._testTensorArrayWriteConcat(dtypes.float64)
209    self._testTensorArrayWriteConcat(dtypes.int32)
210    self._testTensorArrayWriteConcat(dtypes.int64)
211    self._testTensorArrayWriteConcat(dtypes.complex64)
212    self._testTensorArrayWriteConcat(dtypes.complex128)
213    self._testTensorArrayWriteConcat(dtypes.string)
214
215  def _testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros(self):
216    with self.cached_session():
217      ta = tensor_array_ops.TensorArray(
218          dtype=dtypes.float32,
219          tensor_array_name="foo",
220          size=3,
221          element_shape=tensor_shape.TensorShape([1, 2]))
222      self.assertAllEqual([[0.0, 0.0]], self.evaluate(ta.read(0)))
223      self.assertAllEqual([[[0.0, 0.0]], [[4.0, 5.0]], [[0.0, 0.0]]],
224                          self.evaluate(ta.write(1, [[4.0, 5.0]]).stack()))
225      self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]],
226                          self.evaluate(ta.write(1, [[4.0, 5.0]]).concat()))
227
228  @test_util.run_v1_only("b/122324791")
229  def testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros(self):
230    self._testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros()
231
232  def _testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros(self):
233    ta = tensor_array_ops.TensorArray(
234        dtype=dtypes.float32,
235        tensor_array_name="foo",
236        size=3)
237    self.assertAllEqual(
238        [[0.0, 0.0]], self.evaluate(ta.write(1, [[4.0, 5.0]]).read(0)))
239    self.assertAllEqual([[[0.0, 0.0]], [[4.0, 5.0]], [[0.0, 0.0]]],
240                        self.evaluate(ta.write(1, [[4.0, 5.0]]).stack()))
241    self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]],
242                        self.evaluate(ta.write(1, [[4.0, 5.0]]).concat()))
243
244  @test_util.run_v1_only("b/122324791")
245  def testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros(self):
246    self._testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros()
247
248  @test_util.run_v1_only("Uses placeholders")
249  def testSkipEagerTensorArrayReadUninitializedInferShapeFillsZeros(self):
250    with self.cached_session() as sess:
251      ta = tensor_array_ops.TensorArray(
252          dtype=dtypes.float32,
253          tensor_array_name="foo",
254          size=3)
255      val = array_ops.placeholder(dtypes.float32)
256      self.assertAllEqual(
257          [[0.0, 0.0]], sess.run(ta.write(1, val).read(0), {val: [[4.0, 5.0]]}))
258
259  def _testTensorArrayUnpackRead(self, tf_dtype):
260    with self.cached_session():
261      convert = _make_converter(tf_dtype)
262
263      ta = _make_ta(3, "foo", dtype=tf_dtype)
264      # Unpack a vector into scalars
265      w0 = ta.unstack(convert([1.0, 2.0, 3.0]))
266      r0 = w0.read(0)
267      r1 = w0.read(1)
268      r2 = w0.read(2)
269
270      d0, d1, d2 = self.evaluate([r0, r1, r2])
271      self.assertAllEqual(convert(1.0), d0)
272      self.assertAllEqual(convert(2.0), d1)
273      self.assertAllEqual(convert(3.0), d2)
274
275      # Unpack a matrix into vectors
276      w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]))
277      r0 = w1.read(0)
278      r1 = w1.read(1)
279      r2 = w1.read(2)
280
281      d0, d1, d2 = self.evaluate([r0, r1, r2])
282      self.assertAllEqual(convert([1.0, 1.1]), d0)
283      self.assertAllEqual(convert([2.0, 2.1]), d1)
284      self.assertAllEqual(convert([3.0, 3.1]), d2)
285
286      # Try unpacking an empty matrix, which should not cause an error.
287      w2 = ta.unstack(convert([[], [], []]))
288      r0 = w2.read(0)
289      r1 = w2.read(1)
290      r2 = w2.read(2)
291
292      d0, d1, d2 = self.evaluate([r0, r1, r2])
293      self.assertAllEqual(convert([]), d0)
294      self.assertAllEqual(convert([]), d1)
295      self.assertAllEqual(convert([]), d2)
296
297  def _testTensorArrayUnpackReadMaybeLegacy(self):
298    self._testTensorArrayUnpackRead(dtypes.float32)
299    self._testTensorArrayUnpackRead(dtypes.float64)
300    self._testTensorArrayUnpackRead(dtypes.int32)
301    self._testTensorArrayUnpackRead(dtypes.int64)
302    self._testTensorArrayUnpackRead(dtypes.complex64)
303    self._testTensorArrayUnpackRead(dtypes.complex128)
304    self._testTensorArrayUnpackRead(dtypes.string)
305
306  def testTensorArrayUnpackRead(self):
307    self._testTensorArrayUnpackReadMaybeLegacy()
308
309  def _testTensorArraySplitRead(self, tf_dtype):
310    with self.cached_session():
311      convert = _make_converter(tf_dtype)
312
313      # Split an empty vector
314      ta = _make_ta(3, "foo", dtype=tf_dtype)
315      lengths = constant_op.constant([0, 0, 0])
316      w0 = ta.split(convert([]), lengths=lengths)
317      r0 = w0.read(0)
318      r1 = w0.read(1)
319      r2 = w0.read(2)
320
321      d0, d1, d2 = self.evaluate([r0, r1, r2])
322      self.assertAllEqual(convert([]), d0)
323      self.assertAllEqual(convert([]), d1)
324      self.assertAllEqual(convert([]), d2)
325
326      # Split a vector
327      lengths = constant_op.constant([2, 0, 1])
328      w0 = ta.split(convert([1.0, 2.0, 3.0]), lengths=lengths)
329      r0 = w0.read(0)
330      r1 = w0.read(1)
331      r2 = w0.read(2)
332
333      d0, d1, d2 = self.evaluate([r0, r1, r2])
334      self.assertAllEqual(convert([1.0, 2.0]), d0)
335      self.assertAllEqual(convert([]), d1)
336      self.assertAllEqual(convert([3.0]), d2)
337
338      # Split a matrix
339      lengths = constant_op.constant([2, 0, 1])
340      w0 = ta.split(
341          convert([[1.0, 101.0], [2.0, 201.0], [3.0, 301.0]]), lengths=lengths)
342      r0 = w0.read(0)
343      r1 = w0.read(1)
344      r2 = w0.read(2)
345
346      d0, d1, d2 = self.evaluate([r0, r1, r2])
347      self.assertAllEqual(convert([[1.0, 101.0], [2.0, 201.0]]), d0)
348      self.assertAllEqual(convert([]).reshape(0, 2), d1)
349      self.assertAllEqual(convert([[3.0, 301.0]]), d2)
350
351  @test_util.deprecated_graph_mode_only
352  def testTensorArraySplitRead(self):
353    self._testTensorArraySplitRead(dtypes.float32)
354    self._testTensorArraySplitRead(dtypes.float64)
355    self._testTensorArraySplitRead(dtypes.int32)
356    self._testTensorArraySplitRead(dtypes.int64)
357    self._testTensorArraySplitRead(dtypes.complex64)
358    self._testTensorArraySplitRead(dtypes.complex128)
359    self._testTensorArraySplitRead(dtypes.string)
360
361  @test_util.disable_control_flow_v2("v2 does not support TensorArray.grad.")
362  @test_util.run_v1_only("v2 does not support TensorArray.grad.")
363  def testSkipEagerTensorGradArrayWriteRead(self):
364    with self.session() as session:
365      ta = tensor_array_ops.TensorArray(
366          dtype=dtypes.float32,
367          tensor_array_name="foo",
368          size=3,
369          infer_shape=False)
370      g_ta = ta.grad("grad")
371
372      w0 = ta.write(0, [[4.0, 5.0]])
373      w1 = w0.write(1, [[1.0]])
374      w2 = w1.write(2, -3.0)
375
376      g_w0 = g_ta.write(0, [[5.0, 6.0]])
377      g_w1 = g_w0.write(1, [[2.0]])
378      g_w2 = g_w1.write(2, -2.0)
379
380      r0 = w2.read(0)
381      r1 = w2.read(1)
382      r2 = w2.read(2)
383
384      g_r0 = g_w2.read(0)
385      g_r1 = g_w2.read(1)
386      g_r2 = g_w2.read(2)
387
388      d0, d1, d2, g_d0, g_d1, g_d2 = session.run([r0, r1, r2, g_r0, g_r1, g_r2])
389      self.assertAllEqual([[4.0, 5.0]], d0)
390      self.assertAllEqual([[1.0]], d1)
391      self.assertAllEqual(-3.0, d2)
392      self.assertAllEqual([[5.0, 6.0]], g_d0)
393      self.assertAllEqual([[2.0]], g_d1)
394      self.assertAllEqual(-2.0, g_d2)
395
396  @test_util.deprecated_graph_mode_only
397  def testSkipEagerTensorArrayGradGrad(self):
398    if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
399      self.skipTest("Legacy TensorArray does not support double derivatives.")
400    with self.test_session() as session:
401      x = constant_op.constant(4.0)
402
403      ta = tensor_array_ops.TensorArray(
404          dtype=dtypes.float32,
405          tensor_array_name="foo",
406          size=1,
407          infer_shape=False)
408      w0 = ta.write(0, x)
409      r0 = w0.read(0)
410      y = r0 * r0
411
412      g1 = gradients_impl.gradients(ys=[y], xs=[x])
413      g2 = gradients_impl.gradients(ys=[g1], xs=[x])
414      self.assertAllEqual([2.0], session.run(g2))
415
416  @test_util.disable_control_flow_v2("v2 does not support TensorArray.grad.")
417  @test_util.run_v1_only("v2 does not support TensorArray.grad.")
418  def testSkipEagerTensorGradArrayDynamicWriteRead(self):
419    with self.session() as session:
420      ta = tensor_array_ops.TensorArray(
421          dtype=dtypes.float32,
422          tensor_array_name="foo",
423          size=0,
424          dynamic_size=True,
425          infer_shape=False)
426
427      w0 = ta.write(0, [[4.0, 5.0]])
428      w1 = w0.write(1, [[1.0]])
429      w2 = w1.write(2, -3.0)
430
431      g_ta = w2.grad("grad")  # Get gradient array here so we know the shape
432
433      s = w2.size()
434      g_s = g_ta.size()
435
436      g_w0 = g_ta.write(0, [[5.0, 6.0]])
437      g_w1 = g_w0.write(1, [[2.0]])
438      g_w2 = g_w1.write(2, -2.0)
439
440      r0 = w2.read(0)
441      r1 = w2.read(1)
442      r2 = w2.read(2)
443
444      g_r0 = g_w2.read(0)
445      g_r1 = g_w2.read(1)
446      g_r2 = g_w2.read(2)
447
448      d0, d1, d2, g_d0, g_d1, g_d2, vs, g_vs = session.run(
449          [r0, r1, r2, g_r0, g_r1, g_r2, s, g_s])
450      self.assertAllEqual([[4.0, 5.0]], d0)
451      self.assertAllEqual([[1.0]], d1)
452      self.assertAllEqual(-3.0, d2)
453      self.assertAllEqual([[5.0, 6.0]], g_d0)
454      self.assertAllEqual([[2.0]], g_d1)
455      self.assertAllEqual(-2.0, g_d2)
456      self.assertAllEqual(3, vs)
457      self.assertAllEqual(3, g_vs)
458
459  @test_util.disable_control_flow_v2("v2 does not support TensorArray.grad.")
460  @test_util.run_v1_only("v2 does not support TensorArray.grad.")
461  def testSkipEagerTensorGradAccessTwiceReceiveSameObject(self):
462    with self.session() as session:
463      ta = tensor_array_ops.TensorArray(
464          dtype=dtypes.float32, tensor_array_name="foo", size=3)
465      g_ta_0 = ta.grad("grad")
466      g_ta_1 = ta.grad("grad")
467
468      with ops.control_dependencies([g_ta_0.write(0, [[4.0, 5.0]]).flow]):
469        # Write with one gradient handle, read with another copy of it
470        r1_0 = g_ta_1.read(0)
471
472      t_g_ta_0, t_g_ta_1, d_r1_0 = session.run(
473          [g_ta_0.handle.op, g_ta_1.handle.op, r1_0])
474      self.assertAllEqual(t_g_ta_0, t_g_ta_1)
475      self.assertAllEqual([[4.0, 5.0]], d_r1_0)
476
477  def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
478    with self.session():
479      ta = _make_ta(3, "foo", dtype=dtypes.float32)
480      # TODO(b/129870929): Remove the last 2 checks (runtime checks) after
481      # back back from preferred_dtype= to dtype= in convert_to_tensor.  Also
482      # restrict error check to only TypeError.
483      error_msg_regex = (
484          "("
485          "Expected float32, got 'wrong_type_scalar' of type 'str' instead."
486          "|"
487          "Cannot convert provided value to EagerTensor. Provided value: "
488          "wrong_type_scalar Requested dtype: float"
489          "|"
490          "TensorArray dtype is float.* but Op is trying to write dtype string"
491          "|"
492          "Invalid data types; op elements string but list elements float"
493          ")")
494      with self.assertRaisesRegex((TypeError, errors.InvalidArgumentError),
495                                  error_msg_regex):
496        self.evaluate(ta.write(0, "wrong_type_scalar").flow)
497
498      if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
499          not context.executing_eagerly()):
500        error_msg = "Trying to modify element -1 in a list with 3 elements."
501      else:
502        error_msg = "index -1"
503      with self.assertRaisesOpError(error_msg):
504        self.evaluate(ta.write(-1, 3.0).flow)
505
506      if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
507          not context.executing_eagerly()):
508        error_msg = "Trying to modify element 3 in a list with 3 elements"
509      else:
510        error_msg = ("Tried to write to index 3 but array is not "
511                     "resizeable and size is: 3")
512      # Test reading from too large an index
513      with self.assertRaisesOpError(error_msg):
514        self.evaluate(ta.write(3, 3.0).flow)
515
516  def testTensorArrayReadWrongIndexOrDataTypeFails(self):
517    with self.session():
518      ta = _make_ta(3, "foo", dtype=dtypes.float32)
519
520      w0 = ta.write(0, [[4.0, 5.0]])
521
522      # Test reading wrong datatype (only possible when constructing graphs).
523      if (not context.executing_eagerly() and
524          not control_flow_util.ENABLE_CONTROL_FLOW_V2):
525        r0_bad = gen_data_flow_ops.tensor_array_read_v3(
526            handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
527        with self.assertRaisesOpError(
528            "TensorArray dtype is float but Op requested dtype double."):
529          self.evaluate(r0_bad)
530
531      if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
532          not context.executing_eagerly()):
533        error_msg = "Trying to access element -1 in a list with 3 elements."
534      else:
535        error_msg = "index -1"
536      # Test reading from a negative index, which is not allowed
537      with self.assertRaisesOpError(error_msg):
538        self.evaluate(ta.read(-1))
539
540      if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
541          not context.executing_eagerly()):
542        error_msg = "Trying to access element 3 in a list with 3 elements."
543      else:
544        error_msg = "Tried to read from index 3 but array size is: 3"
545      # Test reading from too large an index
546      with self.assertRaisesOpError(error_msg):
547        self.evaluate(ta.read(3))
548
549  @test_util.disable_control_flow_v2("v2 allows multiple writes.")
550  @test_util.run_v1_only("v2 allows multiple writes.")
551  def testSkipEagerTensorArrayWriteMultipleFails(self):
552    with self.session():
553      ta = tensor_array_ops.TensorArray(
554          dtype=dtypes.float32, tensor_array_name="foo", size=3)
555
556      with self.assertRaisesOpError(
557          "Could not write to TensorArray index 2 because "
558          "it has already been written to."):
559        self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow)
560
561  def testTensorArrayConcatIncompatibleShapesFails(self):
562    with self.session():
563      ta = tensor_array_ops.TensorArray(
564          dtype=dtypes.float32,
565          tensor_array_name="foo",
566          size=3,
567          infer_shape=False)
568
569      w1 = ta.write(0, 3.0)
570      w2 = w1.write(1, 4.0)
571      w3 = w2.write(2, [3.0])
572
573      with self.assertRaisesOpError(
574          "Concat saw a scalar shape at index 0 but requires at least vectors"):
575        self.evaluate(w3.concat())
576
577      ta = tensor_array_ops.TensorArray(
578          dtype=dtypes.float32,
579          tensor_array_name="foo",
580          size=3,
581          infer_shape=False)
582
583      w1 = ta.write(0, [3.0])
584      w2 = w1.write(1, [4.0])
585      w3 = w2.write(2, [[3.0]])
586
587      # The exact error messages differ between eager execution and graph
588      # construction as the former bubbles up the error from array_op.concat.
589      error_msg = ("Incompatible ranks"
590                   if control_flow_util.ENABLE_CONTROL_FLOW_V2 and
591                   not context.executing_eagerly() else "shape")
592      with self.assertRaisesRegex(errors.InvalidArgumentError, error_msg):
593        self.evaluate(w3.concat())
594
595  def testTensorArraySplitIncompatibleShapesFails(self):
596    with self.session():
597      in_eager_mode = context.executing_eagerly()
598      ta = _make_ta(3, "foo")
599      with self.assertRaisesOpError(
600          r"Expected lengths to be a vector, received shape: \[\]"):
601        if in_eager_mode:
602          self.evaluate(ta.split([1.0, 2.0, 3.0], 1))
603        else:
604          lengths = array_ops.placeholder(dtypes.int64)
605          ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
606
607      error_msg = ("Unused values in tensor. Length of tensor: 3 Values used: 1"
608                   if control_flow_util.ENABLE_CONTROL_FLOW_V2 and
609                   not in_eager_mode else
610                   r"Expected sum of lengths to be equal to values.shape\[0\], "
611                   r"but sum of lengths is 1 and value's shape is: \[3\]")
612      with self.assertRaisesOpError(error_msg):
613        self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow)
614
615      ta = _make_ta(1, "baz")
616      if control_flow_util.ENABLE_CONTROL_FLOW_V2 and not in_eager_mode:
617        with self.assertRaisesRegex(
618            ValueError, "Shape must be at least rank 1 but is rank 0"):
619          self.evaluate(ta.split(1.0, [1]).flow)
620      else:
621        with self.assertRaisesOpError(
622            r"Expected value to be at least a vector, but received shape: \[\]"
623        ):
624          self.evaluate(ta.split(1.0, [1]).flow)
625
626      if not control_flow_util.ENABLE_CONTROL_FLOW_V2 or in_eager_mode:
627        ta = _make_ta(2, "buz")
628        with self.assertRaisesOpError(
629            r"TensorArray's size is not equal to the size of lengths "
630            r"\(2 vs. 1\), and the TensorArray is not marked as "
631            r"dynamically resizeable"):
632          self.evaluate(ta.split([1.0], [1]).flow)
633
634  def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
635    with self.cached_session():
636      ta = tensor_array_ops.TensorArray(
637          dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False)
638      ta_grad = ta.grad("grad")
639
640      c = lambda x: np.asarray(x, dtype=dtype.as_numpy_dtype)
641
642      w0 = ta.write(2, c(3.0))
643      w1 = w0.write(2, c(4.0))
644
645      w0_grad = ta_grad.write(2, c(3.0))
646      w1_grad = w0_grad.write(2, c(4.0))
647      w2_grad = w1_grad.write(2, c(5.0))
648
649      # Assert that aggregation works correctly
650      self.assertAllEqual(c(12.00), w2_grad.read(2))
651
652      # Assert that if multiple_writes_aggregate is not enabled,
653      # multiple writes raise an exception.
654      with self.assertRaisesOpError(
655          r"TensorArray foo_.*: Could not write to TensorArray index 2 because "
656          r"it has already been written to."):
657        self.evaluate(w1.flow)
658
659      # Using differing shapes causes an exception
660      wb0_grad = ta_grad.write(1, c(1.0))
661      wb1_grad = wb0_grad.write(1, c([1.0]))
662
663      with self.assertRaisesOpError(
664          r"Could not aggregate to TensorArray index 1 because the "
665          r"existing shape is \[\] but the new input shape is \[1\]"):
666        self.evaluate(wb1_grad.flow)
667
668  @test_util.disable_control_flow_v2("v2 does not support TensorArray.grad.")
669  @test_util.run_v1_only("v2 does not support TensorArray.grad.")
670  def testSkipEagerTensorArrayWriteGradientAddMultipleAdds(self):
671    for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64,
672                  dtypes.complex64, dtypes.complex128):
673      self._testTensorArrayWriteGradientAddMultipleAdds(dtype)
674
675  @test_util.disable_control_flow_v2("Low level legacy TA op test.")
676  @test_util.run_v1_only("Low level legacy TA op test.")
677  def testSkipEagerTensorArrayGradWithShapeKnownElementShape(self):
678    with self.session() as sess:
679      ta = tensor_array_ops.TensorArray(
680          size=3,
681          dtype=dtypes.float32,
682          element_shape=tensor_shape.TensorShape([2, 3]))
683      handle, flow = data_flow_ops.tensor_array_grad_with_shape(
684          handle=ta.handle,
685          flow_in=ta.flow,
686          shape_to_prepend=tensor_shape.TensorShape([4, 5]),
687          source="source")
688      ta_grad = tensor_array_ops.TensorArray(
689          dtypes.float32, handle=handle, flow=flow)
690      value = array_ops.placeholder(dtypes.float32)
691      ta_grad = ta_grad.write(0, value)
692      read_value = ta_grad.read(0)
693
694      # Make sure shape inference worked.
695      self.assertAllEqual([None, None, 2, 3], read_value.shape.as_list())
696      # Writing with wrong shape should not work.
697      with self.assertRaisesRegex(errors.InvalidArgumentError,
698                                  "Could not write to TensorArray"):
699        fed_value = np.random.random([2, 3])
700        sess.run(read_value, feed_dict={value: fed_value})
701      # Writing with correct shape should work.
702      fed_value = np.random.random([4, 5, 2, 3])
703      self.assertAllClose(fed_value,
704                          sess.run(read_value, feed_dict={value: fed_value}))
705
706  @test_util.disable_control_flow_v2("Low level legacy TA op test.")
707  @test_util.run_v1_only("Low level legacy TA op test.")
708  def testSkipEagerTensorArrayGradWithShapeUnknownElementShape(self):
709    with self.session() as sess:
710      ta = tensor_array_ops.TensorArray(
711          size=3, dtype=dtypes.float32,
712          element_shape=None)  # Note that element_shape is unknown
713      handle, flow = data_flow_ops.tensor_array_grad_with_shape(
714          handle=ta.handle,
715          flow_in=ta.flow,
716          shape_to_prepend=tensor_shape.TensorShape([4, 5]),
717          source="source")
718      ta_grad = tensor_array_ops.TensorArray(
719          dtypes.float32, handle=handle, flow=flow)
720      value = array_ops.placeholder(dtypes.float32)
721      ta_grad = ta_grad.write(0, value)
722      read_value = ta_grad.read(0)
723
724      # Make sure shape inference worked.
725      self.assertIsNone(read_value.shape.ndims)
726      # Write with some shape and check read value.
727      fed_value = np.random.random([4, 5, 7])
728      self.assertAllClose(fed_value,
729                          sess.run(read_value, feed_dict={value: fed_value}))
730
731  def testMultiTensorArray(self):
732    with self.session():
733      h1 = tensor_array_ops.TensorArray(
734          size=1, dtype=dtypes.float32, tensor_array_name="foo")
735      w1 = h1.write(0, 4.0)
736      r1 = w1.read(0)
737
738      h2 = tensor_array_ops.TensorArray(
739          size=1, dtype=dtypes.float32, tensor_array_name="bar")
740
741      w2 = h2.write(0, 5.0)
742      r2 = w2.read(0)
743      r = r1 + r2
744      val = self.evaluate(r)
745      self.assertAllClose(9.0, val)
746
747  def _testTensorArrayGradientWriteReadType(self, dtype):
748    with self.cached_session() as session:
749      ta = tensor_array_ops.TensorArray(
750          dtype=dtypes.as_dtype(dtype),
751          tensor_array_name="foo",
752          size=3,
753          infer_shape=False)
754
755      c = lambda x: np.array(x, dtype=dtype)
756
757      value_0 = constant_op.constant(c([[4.0, 5.0]]))
758      value_1 = constant_op.constant(c(3.0))
759
760      w0 = ta.write(0, value_0)
761      w1 = w0.write(1, value_1)
762      r0 = w1.read(0)
763      r1 = w1.read(1)
764      r0_2 = w1.read(0)
765
766      # Test individual components' gradients
767      grad_just_r0 = gradients_impl.gradients(
768          ys=[r0], xs=[value_0], grad_ys=[c([[2.0, 3.0]])])
769      grad_just_r0_vals = session.run(grad_just_r0)
770      self.assertAllEqual(c([[2.0, 3.0]]), grad_just_r0_vals[0])
771
772      grad_r0_r0_2 = gradients_impl.gradients(
773          ys=[r0, r0_2],
774          xs=[value_0],
775          grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]])])
776      grad_r0_r0_2_vals = session.run(grad_r0_r0_2)
777      self.assertAllEqual(c([[3.0, 2.0]]), grad_r0_r0_2_vals[0])
778
779      grad_just_r1 = gradients_impl.gradients(
780          ys=[r1], xs=[value_1], grad_ys=[c(-2.0)])
781      grad_just_r1_vals = session.run(grad_just_r1)
782      self.assertAllEqual(c(-2.0), grad_just_r1_vals[0])
783
784      # Test combined gradients
785      grad = gradients_impl.gradients(
786          ys=[r0, r0_2, r1],
787          xs=[value_0, value_1],
788          grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]]), c(-2.0)])
789      grad_vals = session.run(grad)
790      self.assertEqual(len(grad_vals), 2)
791      self.assertAllEqual(c([[3.0, 2.0]]), grad_vals[0])
792      self.assertAllEqual(c(-2.0), grad_vals[1])
793
794  @test_util.deprecated_graph_mode_only
795  def testSkipEagerTensorArrayGradientWriteRead(self):
796    for dtype in (np.float32, np.float64, np.complex64, np.complex128):
797      self._testTensorArrayGradientWriteReadType(dtype)
798
799  def _testTensorArrayGradientWritePackConcatAndRead(self):
800    with self.cached_session():
801      ta = tensor_array_ops.TensorArray(
802          dtype=dtypes.float32,
803          tensor_array_name="foo",
804          size=2,
805          clear_after_read=False)
806
807      value_0 = constant_op.constant([-1.0, 1.0])
808      value_1 = constant_op.constant([-10.0, 10.0])
809
810      w0 = ta.write(0, value_0)
811      w1 = w0.write(1, value_1)
812      p0 = w1.stack()
813      r0 = w1.read(0)
814      s0 = w1.concat()
815
816      # Test gradient accumulation between read(0), pack(), and concat()
817      with ops.control_dependencies([p0, r0, s0]):
818        grad_r = gradients_impl.gradients(
819            ys=[p0, r0, s0],
820            xs=[value_0, value_1],
821            grad_ys=[
822                [[2.0, 3.0], [4.0, 5.0]],  # pack gradient
823                [-0.5, 1.5],  # read(0) gradient
824                [20.0, 30.0, 40.0, 50.0]
825            ])  # concat gradient
826      grad_vals = self.evaluate(grad_r)  # 2 + 2 entries
827
828      self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0])
829      self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1])
830
831  @test_util.deprecated_graph_mode_only
832  def testSkipEagerTensorArrayGradientWritePackConcatAndRead(self):
833    self._testTensorArrayGradientWritePackConcatAndRead()
834
835  @test_util.disable_control_flow_v2("v2 does not support clear_after_read.")
836  @test_util.run_v1_only("v2 does not support clear_after_read.")
837  def testTensorArrayReadTwice(self):
838    with self.session():
839      value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
840
841      ta_readonce = tensor_array_ops.TensorArray(
842          dtype=dtypes.float32, tensor_array_name="foo", size=2)
843
844      w_readonce = ta_readonce.unstack(value)
845      r0_readonce = w_readonce.read(0)
846
847      with self.assertRaisesOpError(
848          r"Could not read index 0 twice because it was cleared after a "
849          r"previous read \(perhaps try setting clear_after_read = false\?\)"):
850        with ops.control_dependencies([r0_readonce]):
851          self.evaluate(w_readonce.read(0))
852
853      ta_readtwice = tensor_array_ops.TensorArray(
854          dtype=dtypes.float32,
855          tensor_array_name="foo",
856          size=2,
857          clear_after_read=False)
858      w_readtwice = ta_readtwice.unstack(value)
859      r0_readtwice = w_readtwice.read(0)
860      with ops.control_dependencies([r0_readtwice]):
861        r1_readtwice = w_readtwice.read(0)
862
863      self.assertAllEqual([1.0, -1.0], self.evaluate(r1_readtwice))
864
865  def _testTensorArrayGradientUnpackRead(self):
866    with self.cached_session() as session:
867      ta = tensor_array_ops.TensorArray(
868          dtype=dtypes.float32,
869          tensor_array_name="foo",
870          size=2,
871          clear_after_read=False)
872
873      value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
874
875      w = ta.unstack(value)
876      r0 = w.read(0)
877      r0_1 = w.read(0)
878      r1 = w.read(1)
879
880      # Test combined gradients + aggregation of read(0)
881      grad = gradients_impl.gradients(
882          ys=[r0, r0_1, r1],
883          xs=[value],
884          grad_ys=[[2.0, 3.0], [-1.5, 1.5], [4.0, 5.0]])
885      grad_vals = session.run(grad)
886
887      self.assertEqual(len(grad_vals), 1)
888      self.assertAllEqual([[2.0 - 1.5, 3.0 + 1.5], [4.0, 5.0]], grad_vals[0])
889
890  @test_util.deprecated_graph_mode_only
891  def testSkipEagerTensorArrayGradientUnpackRead(self):
892    self._testTensorArrayGradientUnpackRead()
893
894  @test_util.deprecated_graph_mode_only
895  def testSkipEagerTensorArrayGradientSplitConcat(self):
896    with self.session() as session:
897      ta = tensor_array_ops.TensorArray(
898          dtype=dtypes.float32, tensor_array_name="foo", size=2,
899          infer_shape=False)
900
901      value = constant_op.constant(
902          [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]])
903
904      w = ta.split(value, [2, 1])
905      r = w.concat()
906
907      # Test combined gradients
908      grad = gradients_impl.gradients(
909          ys=[r],
910          xs=[value],
911          grad_ys=[[[2.0, -2.0], [20.0, -20.0], [200.0, -200.0]]])
912      grad_vals = session.run(grad)
913
914      self.assertEqual(len(grad_vals), 1)
915      self.assertAllEqual([[2.0, -2.0], [20.0, -20.0], [200.0, -200.0]],
916                          grad_vals[0])
917
918  def _testTensorArrayGradientDynamicUnpackRead(self):
919    with self.cached_session() as session:
920      ta = tensor_array_ops.TensorArray(
921          dtype=dtypes.float32,
922          tensor_array_name="foo",
923          size=0,
924          dynamic_size=True)
925
926      value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
927
928      w = ta.unstack(value)
929      r0 = w.read(0)
930      r1 = w.read(1)
931
932      # Test combined gradients + aggregation of read(0)
933      grad = gradients_impl.gradients(
934          ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]])
935      grad_vals = session.run(grad)
936
937      self.assertEqual(len(grad_vals), 1)
938      self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
939
940  @test_util.deprecated_graph_mode_only
941  def testSkipEagerTensorArrayGradientDynamicUnpackRead(self):
942    self._testTensorArrayGradientDynamicUnpackRead()
943
944  def testCloseTensorArray(self):
945    with self.session():
946      ta = tensor_array_ops.TensorArray(
947          dtype=dtypes.float32, tensor_array_name="foo", size=3)
948      self.evaluate(ta.close())
949
950  def testSizeTensorArray(self):
951    with self.session():
952      ta = tensor_array_ops.TensorArray(
953          dtype=dtypes.float32, tensor_array_name="foo", size=3)
954      s = ta.size()
955      self.assertAllEqual(3, self.evaluate(s))
956
957  def testWriteCloseTensorArray(self):
958    with self.session():
959      ta = tensor_array_ops.TensorArray(
960          dtype=dtypes.float32,
961          tensor_array_name="foo",
962          size=3,
963          infer_shape=False)
964      w0 = ta.write(0, [[4.0, 5.0]])
965      w1 = w0.write(1, [3.0])
966      self.evaluate(w1.close())  # Expected to run without problems
967
968  def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
969    np_dtype = dtype.as_numpy_dtype
970    with self.cached_session():
971
972      def func(v0, state0, var):
973        ta = tensor_array_ops.TensorArray(
974            dtype=dtype,
975            tensor_array_name="foo",
976            size=0 if dynamic_size else 3,
977            dynamic_size=dynamic_size)
978        time_0 = array_ops.identity(0)
979
980        def body(time, ta_t, state):
981          sliced = array_ops.slice(
982              v0, begin=array_ops.stack([time, 0]), size=[1, -1])
983          sliced = array_ops.squeeze(sliced)
984          out = sliced + var + state
985          state += sliced
986          ta_t = ta_t.write(time, out)
987          return (time + 1, ta_t, state)
988
989        (unused_0, h_final, unused_2) = control_flow_ops.while_loop(
990            cond=lambda time, unused_1, unused_2: time < 3,
991            body=body,
992            loop_vars=(time_0, ta, state0),
993            shape_invariants=(time_0.get_shape(), tensor_shape.unknown_shape(),
994                              tensor_shape.unknown_shape()),
995            parallel_iterations=3)
996        vout = h_final.stack()
997        return vout
998
999      v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5))
1000      state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype))
1001      init_val = np.arange(100, 105, dtype=np_dtype)
1002      var = variable_scope.get_variable(
1003          "var",
1004          shape=init_val.shape,
1005          dtype=np_dtype,
1006          initializer=init_ops.constant_initializer(init_val))
1007
1008      vout = func(v0, state0, var)
1009      grad_val = -np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)
1010      if context.executing_eagerly():
1011        grad_fn = backprop.gradients_function(func)
1012        v0_grad, state0_grad, var_grad = grad_fn(v0, state0, var, dy=grad_val)
1013      else:
1014        v0_grad = gradients_impl.gradients([vout], [v0], [grad_val])[0]
1015        state0_grad = gradients_impl.gradients([vout], [state0], [grad_val])[0]
1016        var_grad = gradients_impl.gradients([vout], [var], [grad_val])[0]
1017        self.evaluate(variables.global_variables_initializer())
1018
1019      state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = (
1020          self.evaluate(
1021              ([state0, var, v0, vout, v0_grad, var_grad, state0_grad])))
1022      just_v0_grad_t = self.evaluate(v0_grad)
1023
1024      # state = [ state0 | state0 + v0[0] | state0 + v0[0] + v0[1] ]
1025      # vout = [ v0[0] + var + state[0] |
1026      #          v0[1] + var + state[1] |
1027      #          v0[2] + var + state[2] ]
1028      #      = [ v0[0] + var + state0 |
1029      #          v0[1] + var + state0 + v0[0] |
1030      #          v0[2] + var + state0 + v0[0] + v0[1] ]
1031      #
1032      # d(vout[0])/d(v0) = [1 | 0 | 0 ]
1033      # d(vout[1])/d(v0) = [1 | 1 | 0 ]
1034      # d(vout[2])/d(v0) = [1 | 1 | 1 ]
1035      # d(vout)/d(var) = [1 | 1 | 1]
1036      # d(vout)/d(state0) = [ 1 | 1 | 1 ]
1037
1038      state_per_time = np.array(
1039          [state0_t, state0_t + v0_t[0, :], state0_t + v0_t[0, :] + v0_t[1, :]])
1040
1041      # Compare forward prop
1042      self.assertAllClose(v0_t + var_t + state_per_time, vout_t)
1043
1044      # Compare backward prop
1045      expected_v0_grad_t = np.array([
1046          grad_val[0, :] + grad_val[1, :] + grad_val[2, :],
1047          grad_val[1, :] + grad_val[2, :], grad_val[2, :]
1048      ])
1049
1050      self.assertAllEqual(expected_v0_grad_t, v0_grad_t)
1051      self.assertAllEqual(expected_v0_grad_t, just_v0_grad_t)
1052      self.assertAllClose(grad_val.sum(axis=0), var_grad_t)
1053      self.assertAllClose(grad_val.sum(axis=0), state0_grad_t)
1054
1055  def testWhileLoopWritePackGradients(self):
1056    self._testWhileLoopWritePackGradients(
1057        dynamic_size=False, dtype=dtypes.float32)
1058    # TODO(ebrevdo): re-enable when While supports non-float32 gradients.
1059    # self._testWhileLoopWritePackGradients(
1060    #     dynamic_size=False, dtype=tf.int64)
1061
1062  @test_util.run_deprecated_v1
1063  def testSkipEagerWhileLoopDynamicWritePackGradients(self):
1064    self._testWhileLoopWritePackGradients(
1065        dynamic_size=True, dtype=dtypes.float32)
1066
1067  def testGradSerialTwoLoops(self):
1068    with self.session():
1069
1070      def loop(x):
1071        num_steps = 100
1072        acc = tensor_array_ops.TensorArray(
1073            dtype=dtypes.float32,
1074            size=num_steps,
1075            clear_after_read=False,
1076            element_shape=tensor_shape.TensorShape([]))
1077        i = constant_op.constant(0, name="i")
1078
1079        c = lambda i, acc: i < 5
1080
1081        def b(i, acc):
1082          x1 = control_flow_ops.cond(
1083              math_ops.equal(i, 0), lambda: x,
1084              lambda: math_ops.multiply(acc.read(i - 1), 2.0))
1085          return i + 1, acc.write(i, x1)
1086
1087        i1, acc1 = control_flow_ops.while_loop(c, b, [i, acc])
1088
1089        z = constant_op.constant(0.0)
1090
1091        def fn(i, acc):
1092          return i + 1, acc.write(i, z)
1093
1094        _, acc2 = control_flow_ops.while_loop(lambda i, acc: i < num_steps, fn,
1095                                              [i1, acc1])
1096
1097        r = acc2.stack()
1098        return r
1099
1100      x = constant_op.constant(2.0, name="x")
1101      if context.executing_eagerly():
1102        grad = backprop.gradients_function(loop)(x)[0]
1103      else:
1104        grad = gradients_impl.gradients(loop(x), [x])[0]
1105      self.assertAllClose(31.0, self.evaluate(grad))
1106
1107  def testShapeAfterWhileLoop(self):
1108    size = 10
1109    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size)
1110    _, ta = control_flow_ops.while_loop(
1111        lambda i, _: i < size,
1112        lambda i, ta: (i + 1, ta.write(i, [[0.]])), [0, ta],
1113        parallel_iterations=1)
1114    self.assertIsNotNone(ta.element_shape.dims)
1115
1116  @test_util.deprecated_graph_mode_only
1117  def testSkipEagerSumOfTwoReadVariablesWithoutRepeatGrad(self):
1118    with self.session() as session:
1119      a = array_ops.identity(
1120          np.arange(
1121              3 * 5, dtype=np.float32).reshape(3, 5) + 1)
1122      b = array_ops.identity(
1123          np.arange(
1124              3 * 5, dtype=np.float32).reshape(3, 5) + 1 + 3 * 5)
1125      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
1126      ta = ta.write(0, a, name="write_a")
1127      ta = ta.write(1, b, name="write_b")
1128      c = (
1129          ta.read(
1130              0, name="read_a_0") +  # a + b
1131          ta.read(
1132              1, name="read_b_0"))
1133      g0 = -(np.arange(3 * 5, dtype=np.float32).reshape(3, 5) + 1)
1134      grad_a = gradients_impl.gradients([c], [a], [g0])[0]  # d(a+b)/da = 1
1135      grad_b = gradients_impl.gradients([c], [b], [g0])[0]  # d(a+b)/db = 1
1136
1137      # Test gradients calculated individually
1138      grad_a_t, = session.run([grad_a])
1139      self.assertAllEqual(grad_a_t, g0)
1140
1141      grad_b_t, = session.run([grad_b])
1142      self.assertAllEqual(grad_b_t, g0)
1143
1144      # Test gradients calculated jointly
1145      joint_grad_a_t, joint_grad_b_t = session.run([grad_a, grad_b])
1146      self.assertAllEqual(joint_grad_a_t, g0)
1147      self.assertAllEqual(joint_grad_b_t, g0)
1148
1149  def _grad_source_for_name(self, name):
1150    return tensor_array_grad._GetGradSource(constant_op.constant(0, name=name))
1151
1152  @test_util.deprecated_graph_mode_only
1153  def testSkipEagerGetGradSource_Invalid(self):
1154    with self.assertRaises(ValueError):
1155      self._grad_source_for_name("")
1156    with self.assertRaises(ValueError):
1157      self._grad_source_for_name("foo")
1158    with self.assertRaises(ValueError):
1159      self._grad_source_for_name("foo/bar")
1160
1161  @test_util.deprecated_graph_mode_only
1162  def testSkipEagerGetGradSource_NoEnclosingScope(self):
1163    self.assertEqual("gradients:0", self._grad_source_for_name("gradients"))
1164    self.assertEqual("gradients_0:0", self._grad_source_for_name("gradients_0"))
1165    self.assertEqual("gradients", self._grad_source_for_name("gradients/foo"))
1166    self.assertEqual("gradients_0",
1167                     self._grad_source_for_name("gradients_0/foo"))
1168    self.assertEqual("gradients",
1169                     self._grad_source_for_name("gradients/foo/bar"))
1170    self.assertEqual("gradients_0",
1171                     self._grad_source_for_name("gradients_0/foo/bar"))
1172
1173  @test_util.deprecated_graph_mode_only
1174  def testSkipEagerGetGradSource_EnclosingScope(self):
1175    self.assertEqual("foo/gradients:0",
1176                     self._grad_source_for_name("foo/gradients"))
1177    self.assertEqual("foo/gradients_0:0",
1178                     self._grad_source_for_name("foo/gradients_0"))
1179    self.assertEqual("foo/gradients",
1180                     self._grad_source_for_name("foo/gradients/bar"))
1181    self.assertEqual("foo/gradients_0",
1182                     self._grad_source_for_name("foo/gradients_0/bar"))
1183    self.assertEqual("foo/bar/gradients",
1184                     self._grad_source_for_name("foo/bar/gradients/baz"))
1185    self.assertEqual("foo/bar/gradients_0",
1186                     self._grad_source_for_name("foo/bar/gradients_0/baz"))
1187
1188  @test_util.deprecated_graph_mode_only
1189  def testSkipEagerGetGradSource_NestedUsesInnermost(self):
1190    self.assertEqual(
1191        "foo/gradients/bar/gradients_0",
1192        self._grad_source_for_name("foo/gradients/bar/gradients_0/baz"))
1193
1194  @test_util.deprecated_graph_mode_only
1195  def testSkipEagerWriteShape(self):
1196    with self.session():
1197      ta = tensor_array_ops.TensorArray(
1198          dtype=dtypes.float32, tensor_array_name="foo", size=3)
1199      c0 = constant_op.constant([4.0, 5.0])
1200      w0 = ta.write(0, c0)
1201      r0 = w0.read(0)
1202      self.assertAllEqual(c0.get_shape(), r0.get_shape())
1203
1204      ta = tensor_array_ops.TensorArray(
1205          dtype=dtypes.float32, tensor_array_name="foo", size=3)
1206      c1 = constant_op.constant([6.0, 7.0])
1207      w1 = w0.write(1, c1)
1208      r0 = w1.read(0)
1209      r1 = w1.read(1)
1210      self.assertAllEqual(c0.get_shape(), r0.get_shape())
1211      self.assertAllEqual(c1.get_shape(), r1.get_shape())
1212
1213      ta = tensor_array_ops.TensorArray(
1214          dtype=dtypes.float32, tensor_array_name="foo", size=3)
1215      c2 = constant_op.constant([4.0, 5.0, 6.0])
1216      with self.assertRaises(ValueError):
1217        w0.write(0, c2)
1218
1219  @test_util.deprecated_graph_mode_only
1220  def testSkipEagerPartlyUnknownShape(self):
1221    with self.session():
1222      ta = tensor_array_ops.TensorArray(
1223          dtype=dtypes.float32, tensor_array_name="foo", size=6)
1224
1225      c0 = array_ops.placeholder(dtypes.float32, [None, None, None, 3])
1226      w0 = ta.write(0, c0)
1227      r0 = w0.read(0)
1228      self.assertAllEqual([None, None, None, 3], r0.get_shape().as_list())
1229
1230      c1 = array_ops.placeholder(dtypes.float32, [None, None, None, 3])
1231      w1 = w0.write(1, c1)
1232      r1 = w1.read(0)
1233      self.assertAllEqual([None, None, None, 3], r1.get_shape().as_list())
1234
1235      # Writing less specific shape (doesn't change type.)
1236      c2 = array_ops.placeholder(dtypes.float32, [None, None, None, None])
1237      w2 = w1.write(2, c2)
1238      r2 = w2.read(0)
1239      self.assertAllEqual([None, None, None, 3], r2.get_shape().as_list())
1240
1241      # Writing more specific shape in one dimension and less specific in
1242      # another.
1243      c3 = array_ops.placeholder(dtypes.float32, [None, None, 2, None])
1244      w3 = w2.write(3, c3)
1245      r3 = w3.read(0)
1246      self.assertAllEqual([None, None, 2, 3], r3.get_shape().as_list())
1247
1248      # Writing partly defined shape using TensorArray.scatter.
1249      c4 = array_ops.placeholder(dtypes.float32, [2, None, 4, 2, 3])
1250      w4 = w3.scatter([4, 5], c4)
1251      r4 = w4.read(0)
1252      self.assertAllEqual([None, 4, 2, 3], r4.get_shape().as_list())
1253
1254      # Writing fully defined shape using TensorArray.split.
1255      c5 = array_ops.placeholder(dtypes.float32, [10, 4, 2, 3])
1256      w5 = w4.split(c5, constant_op.constant([5, 5]))
1257      r5 = w5.read(0)
1258      self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list())
1259
1260  def _testUnpackShape(self):
1261    with self.cached_session():
1262      ta = tensor_array_ops.TensorArray(
1263          dtype=dtypes.float32,
1264          tensor_array_name="foo",
1265          size=0,
1266          dynamic_size=True,
1267          infer_shape=True)
1268      value = constant_op.constant(
1269          [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]])
1270      w0 = ta.unstack(value)
1271      r0 = w0.read(0)
1272      self.assertAllEqual((2,), r0.get_shape())
1273
1274      c1 = constant_op.constant([4.0, 5.0])
1275      w1 = w0.write(3, c1)
1276
1277      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
1278        # TensorArray v2 does not support clear_after_read.
1279        with self.assertRaisesOpError(
1280            r"Could not read index 0 twice because it was cleared after a "
1281            r"previous read \(perhaps try setting clear_after_read = false\?\)"
1282        ):
1283          with ops.control_dependencies([r0]):
1284            self.evaluate(w1.read(0))
1285
1286      r1 = w1.read(1)
1287      self.assertAllEqual(c1.get_shape(), r1.shape)
1288
1289      c2 = constant_op.constant([4.0, 5.0, 6.0])
1290      with self.assertRaises(ValueError):
1291        w1.write(4, c2)
1292
1293  def testUnpackShape(self):
1294    self._testUnpackShape()
1295
1296  @test_util.deprecated_graph_mode_only
1297  def testSplitShape(self):
1298    with self.session():
1299      ta = tensor_array_ops.TensorArray(
1300          dtype=dtypes.float32,
1301          tensor_array_name="foo",
1302          size=0,
1303          dynamic_size=True,
1304          infer_shape=True)
1305      value = constant_op.constant([[1.0, -1.0], [2.0, -2.0], [3.0, -3.0]])
1306      w0 = ta.split(value, [1, 1, 1])
1307      r0 = w0.read(0)
1308      self.assertAllEqual((1, 2), r0.get_shape())
1309
1310      ta1 = tensor_array_ops.TensorArray(
1311          dtype=dtypes.float32,
1312          tensor_array_name="foo1",
1313          size=0,
1314          dynamic_size=True,
1315          infer_shape=True)
1316      w0 = ta1.split(value, [1, 2])
1317      r0 = w0.read(0)
1318      if context.executing_eagerly():
1319        self.assertEqual((1, 2), r0.get_shape())
1320        self.assertEqual((2, 2), w0.read(1).get_shape())
1321      else:
1322        self.assertEqual(r0.get_shape().ndims, None)
1323        if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
1324          self.assertEqual(
1325              tensor_shape.TensorShape(
1326                  ta1.handle.op.get_attr("element_shape")).ndims, None)
1327
1328  @test_util.deprecated_graph_mode_only
1329  def testSkipEagerWriteUnknownShape(self):
1330    with self.session():
1331      ta = tensor_array_ops.TensorArray(
1332          dtype=dtypes.float32,
1333          tensor_array_name="foo",
1334          size=3,
1335          infer_shape=True)
1336      c0 = array_ops.placeholder(dtypes.float32)
1337      w0 = ta.write(0, c0)
1338      r0 = w0.read(0)
1339      self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape())
1340
1341  def _testGradientWhenNotAllComponentsRead(self):
1342    with self.cached_session() as session:
1343      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
1344      x = constant_op.constant([2.0, 3.0])
1345      w = ta.unstack(x)
1346      r0 = w.read(0)
1347      # calculate (dr0/dx0, dr0/dx1).  since r0 = x0, gradients are (1, 0).
1348      grad_r0 = gradients_impl.gradients(ys=[r0], xs=[x], grad_ys=[1.0])
1349      grad_r0_vals = session.run(grad_r0)[0]
1350      self.assertAllEqual(grad_r0_vals, [1.0, 0.0])
1351
1352  @test_util.deprecated_graph_mode_only
1353  def testSkipEagerGradientWhenNotAllComponentsRead(self):
1354    self._testGradientWhenNotAllComponentsRead()
1355
1356  @test_util.deprecated_graph_mode_only
1357  def testSkipEagerWriteButNotAllComponentsReadGrad(self):
1358    with self.cached_session() as session:
1359      x0 = constant_op.constant(5.0)
1360      x1 = constant_op.constant(10.0)
1361      ta = tensor_array_ops.TensorArray(
1362          dtype=dtypes.float32, size=2).write(0, x0).write(1, x1)
1363      r0 = ta.read(0)
1364      # calculate (dr0/dx0, dr0/dx1).  since r0 = x0, gradients are (1, 0).
1365      grad_r0_x1 = gradients_impl.gradients(ys=[r0], xs=[x0, x1], grad_ys=[1.0])
1366      grad_r0_x1_vals = session.run(grad_r0_x1)
1367      self.assertAllEqual(grad_r0_x1_vals, [1.0, 0.0])
1368
1369  def _testTensorArrayUnpackDynamic(self):
1370    with self.cached_session():
1371      ta = tensor_array_ops.TensorArray(
1372          dtype=dtypes.float32, size=3, dynamic_size=True)
1373      x = constant_op.constant([1.0, 2.0, 3.0])
1374      w0 = ta.unstack(x)
1375      w1 = w0.write(3, 4.0)
1376      r = w1.stack()
1377      self.assertAllEqual(np.array([1.0, 2.0, 3.0, 4.0]), self.evaluate(r))
1378      grad = gradients_impl.gradients(ys=[r], xs=[x])
1379      self.assertAllEqual(np.array([1.0, 1.0, 1.0]), self.evaluate(grad)[0])
1380
1381  @test_util.run_deprecated_v1
1382  def testSkipEagerTensorArrayUnpackDynamic(self):
1383    self._testTensorArrayUnpackDynamic()
1384
1385  @test_util.run_deprecated_v1
1386  def testSkipEagerTensorArraySplitDynamic(self):
1387    with self.session():
1388      ta = tensor_array_ops.TensorArray(
1389          dtype=dtypes.float32, size=3, dynamic_size=True)
1390      x = constant_op.constant([1.0, 2.0, 3.0])
1391      w0 = ta.split(x, [1, 1, 1])
1392      w1 = w0.write(3, [4.0])
1393      r = w1.concat()
1394      self.assertAllEqual(np.array([1.0, 2.0, 3.0, 4.0]), self.evaluate(r))
1395      grad = gradients_impl.gradients(ys=[r], xs=[x])
1396      self.assertAllEqual(np.array([1.0, 1.0, 1.0]), self.evaluate(grad)[0])
1397
1398  def testStackShape(self):
1399
1400    @def_function.function
1401    def ta_stack():
1402      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
1403      x = constant_op.constant([1.0, 2.0, 3.0])
1404      ta = ta.write(0, x)
1405      t = ta.stack()
1406      self.assertEqual(t.shape.as_list(), [3, 3])
1407      return t
1408
1409    ta_stack()
1410
1411  def testReadShape(self):
1412
1413    @def_function.function
1414    def ta_read():
1415      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
1416      x = constant_op.constant([1.0, 2.0, 3.0])
1417      ta = ta.write(0, x)
1418      t = ta.read(0)
1419      self.assertEqual(t.shape.as_list(), [3])
1420      return t
1421
1422    ta_read()
1423
1424  def testGatherShape(self):
1425
1426    def ta_gather(indices):
1427      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
1428      x = constant_op.constant([1.0, 2.0, 3.0])
1429      ta = ta.write(0, x)
1430      t = ta.gather(indices)
1431      self.assertEqual(t.shape.as_list(), [first_dim, 3])
1432      return t
1433
1434    # This propagates shape of `indices` when compiling ta_gather.
1435    ta_gather_with_known_indices_shape = def_function.function(ta_gather)
1436    first_dim = 1
1437    ta_gather_with_known_indices_shape([0])
1438
1439    # Here were force the shape of `indices` to be [None] during ta_gather's
1440    # compilation.
1441    ta_gather_with_unknown_indices_shape = def_function.function(
1442        ta_gather,
1443        input_signature=[
1444            tensor_spec.TensorSpec(dtype=dtypes.int32, shape=[None])
1445        ])
1446    first_dim = None
1447    ta_gather_with_unknown_indices_shape([0])
1448
1449  def _testTensorArrayEvalEmpty(self):
1450    with self.cached_session():
1451      ta = tensor_array_ops.TensorArray(
1452          dtype=dtypes.float32, size=0, dynamic_size=False, infer_shape=False)
1453      v2_msg = ("Tried to stack elements of an empty list with "
1454                "non-fully-defined element_shape")
1455      v1_msg = (
1456          "TensorArray has size zero, but element shape <unknown> is not "
1457          "fully defined. Currently only static shapes are supported when "
1458          "packing zero-size TensorArrays.")
1459      with self.assertRaisesOpError(
1460          v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
1461        ta.stack().eval()
1462
1463  @test_util.run_deprecated_v1
1464  def testSkipEagerTensorArrayEvalEmpty(self):
1465    self._testTensorArrayEvalEmpty()
1466
1467  # this test is ill-defined for Eager mode --- unpacking an empty tensor
1468  # gives an empty list / there is not equivalent of "mark_used" in Eager
1469  def _testTensorArrayEvalEmptyWithDefault(self):
1470    with self.cached_session():
1471      ta = tensor_array_ops.TensorArray(
1472          dtype=dtypes.float32, size=0, dynamic_size=False, infer_shape=True)
1473      self.assertEqual(0, ta.size().eval())
1474      # Don't actually perform the pack.  This stores the static shape.
1475      if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1476        ta = ta.unstack(array_ops.zeros([0, 3, 5]))
1477      else:
1478        ta.unstack(array_ops.zeros([0, 3, 5])).mark_used()
1479      packed = ta.stack()
1480      concatenated = ta.concat()
1481      self.assertAllEqual([0, 3, 5], self.evaluate(packed).shape)
1482      # Concatenating zero tensors along their first dimension gives a
1483      # first dimension of zero
1484      self.assertAllEqual([0, 5], self.evaluate(concatenated).shape)
1485
1486  @test_util.run_deprecated_v1
1487  def testSkipEagerTensorArrayEvalEmptyWithDefault(self):
1488    self._testTensorArrayEvalEmptyWithDefault()
1489
1490  @test_util.run_deprecated_v1
1491  def testSkipEagerTensorArrayScatterReadAndGradients(self):
1492    with self.session() as session:
1493      ta = tensor_array_ops.TensorArray(
1494          dtype=dtypes.float32,
1495          tensor_array_name="foo",
1496          size=0,
1497          dynamic_size=True)
1498
1499      indices = constant_op.constant([1, 8])
1500      value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
1501
1502      w = ta.scatter(indices, value)
1503      r0 = w.read(1)
1504      r1 = w.read(8)
1505
1506      # Test combined gradients + aggregation of read(0)
1507      grad = gradients_impl.gradients(
1508          ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]])
1509      read_vals, grad_vals = session.run([[r0, r1], grad])
1510
1511      self.assertEqual(len(read_vals), 2)
1512      self.assertEqual(len(grad_vals), 1)
1513      self.assertAllEqual([1.0, -1.0], read_vals[0])
1514      self.assertAllEqual([10.0, -10.0], read_vals[1])
1515      self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
1516
1517  @test_util.run_deprecated_v1
1518  def testSkipEagerTensorArrayScatterPartialReadAndGradients(self):
1519    with self.session() as session:
1520      ta = tensor_array_ops.TensorArray(
1521          dtype=dtypes.float32,
1522          tensor_array_name="foo",
1523          size=0,
1524          dynamic_size=True)
1525
1526      indices = constant_op.constant([1, 8])
1527      value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
1528
1529      w = ta.scatter(indices, value)
1530      r0 = w.read(1)
1531
1532      # Test combined gradients + aggregation of read(0)
1533      grad = gradients_impl.gradients(
1534          ys=[r0], xs=[value], grad_ys=[[2.0, 3.0]])[0]
1535      read_val, grad_val = session.run([r0, grad])
1536
1537      self.assertAllEqual([1.0, -1.0], read_val)
1538      self.assertAllEqual([[2.0, 3.0], [0.0, 0.0]], grad_val)
1539
1540  def testScatterIntoExistingList(self):
1541    ta = tensor_array_ops.TensorArray(
1542        dtype=dtypes.float32, tensor_array_name="foo", size=5)
1543
1544    ta = ta.scatter(indices=[3, 4], value=array_ops.ones([2]))
1545    self.assertAllEqual(ta.stack(), [0., 0., 0., 1., 1.])
1546
1547    ta = ta.scatter(indices=[1], value=array_ops.ones([1]))
1548    self.assertAllEqual(ta.stack(), [0., 1., 0., 1., 1.])
1549
1550    ta = ta.scatter(indices=[0, 2], value=[5., 6.])
1551    self.assertAllEqual(ta.stack(), [5., 1., 6., 1., 1.])
1552
1553  @test_util.run_v1_only("b/118890905")
1554  def testTensorArrayWriteGatherAndGradients(self):
1555    with self.session() as session:
1556      ta = tensor_array_ops.TensorArray(
1557          dtype=dtypes.float32,
1558          tensor_array_name="foo",
1559          size=0,
1560          dynamic_size=True)
1561
1562      def func(values):
1563        indices = constant_op.constant([1, 8])
1564        w = ta.unstack(values)
1565        g = w.gather(indices)
1566        return g
1567
1568      values = constant_op.constant([[1.0 * x, -1.0 * x] for x in range(10)])
1569      g = func(values)
1570      grad_ys = [[[2.0, 3.0], [4.0, 5.0]]]
1571      # Test combined gradients + aggregation of read(0)
1572      if context.executing_eagerly():
1573        g_vals = [g]
1574        grad_vals = backprop.gradients_function(func)(
1575            values, dy=constant_op.constant(grad_ys[0], dtype=dtypes.float32))
1576      else:
1577        grad = gradients_impl.gradients(ys=[g], xs=[values], grad_ys=grad_ys)
1578        g_vals, grad_vals = session.run([[g], grad])
1579
1580      # Gradients for 8 of the 10 unread components are zero.
1581      expected_grad = np.zeros((10, 2))
1582      expected_grad[1] = [2.0, 3.0]
1583      expected_grad[8] = [4.0, 5.0]
1584
1585      self.assertEqual(len(g_vals), 1)
1586      self.assertEqual(len(grad_vals), 1)
1587      self.assertAllEqual([[1.0, -1.0], [8.0, -8.0]], g_vals[0])
1588      self.assertAllEqual(expected_grad, grad_vals[0])
1589
1590  @test_util.disable_control_flow_v2("colocate_with not supported in v2.")
1591  @test_util.run_v1_only("b/120545219")
1592  def testSkipEagerTensorArrayGetsDeviceFromFirstWrite(self):
1593    with ops.device("/job:worker/task:0/cpu:0"):
1594      # this initial device will be ignored.
1595      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
1596    with ops.device("/job:worker/task:1/cpu:0"):
1597      # the first write sets the op's device.
1598      ta = ta.write(0, 1.0)
1599    with ops.device("/job:worker/task:2/cpu:0"):
1600      # subsequent writes do not modify the op's device.
1601      ta = ta.write(1, 1.0)
1602
1603    # The gradient TA will sit on the same device as the forward TA.
1604    ta_grad = ta.grad("grad")
1605    flows = [ta.flow, ta_grad.flow]
1606
1607    # Similar tests for unpack and split
1608    with ops.device("/job:worker/task:0/cpu:0"):
1609      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
1610    with ops.device("/job:worker/task:1/cpu:0"):
1611      ta = ta.unstack([1.0, 2.0])
1612    with ops.device("/job:worker/task:2/cpu:0"):
1613      ta = ta.write(2, 3.0)
1614    flows.append(ta.flow)
1615
1616    with ops.device("/job:worker/task:0/cpu:0"):
1617      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
1618    with ops.device("/job:worker/task:1/cpu:0"):
1619      ta = ta.split([1.0, 2.0], [1, 1])
1620    flows.append(ta.flow)
1621
1622    session = session_lib.Session(self._workers[0].target)
1623
1624    run_options = config_pb2.RunOptions(
1625        trace_level=config_pb2.RunOptions.FULL_TRACE)
1626    run_metadata = config_pb2.RunMetadata()
1627
1628    session.run(flows, options=run_options, run_metadata=run_metadata)
1629    self.assertTrue(run_metadata.HasField("step_stats"))
1630    dev_stats = {d.device: d.node_stats
1631                 for d in run_metadata.step_stats.dev_stats}
1632    for d in dev_stats:
1633      if "/task:1/" in d:
1634        self.assertTrue(
1635            [s for s in dev_stats[d] if "/TensorArray" in s.node_name])
1636      elif "/host:CPU" not in d:
1637        self.assertFalse(
1638            [s for s in dev_stats[d] if "/TensorArray" in s.node_name])
1639
1640  @test_util.disable_control_flow_v2("colocate_with not supported in v2.")
1641  @test_util.run_v1_only("b/120545219")
1642  def testSkipEagerTensorArrayGetsDeviceFromFirstWriteInWhileLoop(self):
1643    with ops.device("/job:worker/task:0/cpu:0"):
1644      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
1645
1646    def _body(i, ta_i):
1647      with ops.device("/job:worker/task:1/cpu:0"):
1648        return i + 1, ta_i.write(i, constant_op.constant(0.0))
1649
1650    _, ta_out = control_flow_ops.while_loop(
1651        lambda i, ta: i < 2, _body, loop_vars=[0, ta])
1652
1653    session = session_lib.Session(self._workers[0].target)
1654
1655    run_options = config_pb2.RunOptions(
1656        trace_level=config_pb2.RunOptions.FULL_TRACE)
1657    run_metadata = config_pb2.RunMetadata()
1658
1659    session.run(ta_out.flow, options=run_options, run_metadata=run_metadata)
1660    self.assertTrue(run_metadata.HasField("step_stats"))
1661    dev_stats = {d.device: d.node_stats
1662                 for d in run_metadata.step_stats.dev_stats}
1663    for d in dev_stats:
1664      if "/task:1/" in d:
1665        self.assertTrue(
1666            [s for s in dev_stats[d] if "TensorArray" == s.node_name])
1667      else:
1668        self.assertFalse(
1669            [s for s in dev_stats[d] if "TensorArray" == s.node_name])
1670
1671  @test_util.disable_control_flow_v2("colocate_with not supported in v2.")
1672  @test_util.run_v1_only("b/120545219")
1673  def testSkipEagerTensorArrayDisabledColocateWithFirstWriteCall(self):
1674    with ops.device("/job:worker/task:0/cpu:0"):
1675      ta = tensor_array_ops.TensorArray(
1676          dtype=dtypes.float32, size=2, colocate_with_first_write_call=False)
1677
1678    def _body(i, ta_i):
1679      with ops.device("/job:worker/task:1/cpu:0"):
1680        return i + 1, ta_i.write(i, constant_op.constant(0.0))
1681
1682    _, ta_out = control_flow_ops.while_loop(
1683        lambda i, ta: i < 2, _body, loop_vars=[0, ta])
1684
1685    session = session_lib.Session(self._workers[0].target)
1686
1687    run_options = config_pb2.RunOptions(
1688        trace_level=config_pb2.RunOptions.FULL_TRACE)
1689    run_metadata = config_pb2.RunMetadata()
1690
1691    session.run(ta_out.flow, options=run_options, run_metadata=run_metadata)
1692    self.assertTrue(run_metadata.HasField("step_stats"))
1693    dev_stats = {d.device: list(d.node_stats)
1694                 for d in run_metadata.step_stats.dev_stats}
1695    for d in dev_stats:
1696      if "/task:0/" in d and "CPU" in d:  # Skip any GPU node stats
1697        self.assertTrue(
1698            [s for s in dev_stats[d] if "TensorArray" == s.node_name])
1699      else:
1700        self.assertFalse(
1701            [s for s in dev_stats[d] if "TensorArray" == s.node_name])
1702
1703  def testTensorArrayIdentity(self):
1704    with self.session():
1705      ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2,
1706                                         infer_shape=False)
1707      ta1 = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=4,
1708                                         infer_shape=True)
1709
1710      ta0 = ta0.write(0, 0.)
1711      ta1 = ta1.write(0, 1)
1712
1713      v0 = variable_scope.get_variable(
1714          "v0", shape=(), initializer=init_ops.zeros_initializer())
1715      v1 = variable_scope.get_variable(
1716          "v1", shape=(), initializer=init_ops.zeros_initializer())
1717
1718      with ops.control_dependencies([v0.assign_add(1)]):
1719        ta0 = ta0.identity()
1720
1721      with ops.control_dependencies([v1.assign_add(1)]):
1722        ta1 = ta1.identity()
1723
1724      read0 = ta0.read(0)
1725      read1 = ta1.read(0)
1726
1727      size0 = ta0.size()
1728      size1 = ta1.size()
1729
1730      # Tests correct properties on new TensorArrays.
1731      self.assertEqual(dtypes.float32, ta0.dtype)
1732      self.assertEqual(dtypes.int32, ta1.dtype)
1733      if context.executing_eagerly():
1734        self.assertEqual(tensor_shape.TensorShape([]), read0.get_shape())
1735      else:
1736        self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape())
1737      self.assertEqual(tensor_shape.TensorShape([]), read1.get_shape())
1738
1739      if not context.executing_eagerly():
1740        self.evaluate(variables.global_variables_initializer())
1741
1742      read0_v, read1_v, size0_v, size1_v = self.evaluate((read0, read1, size0,
1743                                                          size1))
1744
1745      # Tests that the control dependencies was added and executed.
1746      self.assertEqual(1, self.evaluate(v0))
1747      self.assertEqual(1, self.evaluate(v1))
1748
1749      # Tests correct TensorArray.
1750      self.assertEqual(read0_v, 0)
1751      self.assertEqual(read1_v, 1)
1752      self.assertEqual(size0_v, 2)
1753      self.assertEqual(size1_v, 4)
1754
1755  @test_util.deprecated_graph_mode_only
1756  def testSkipEagerTensorArrayGradYsInCorrectScope(self):
1757    n_time = 1
1758    n_dim = 1
1759    x = constant_op.constant([[1.42]])
1760    dy = constant_op.constant([[2.42]])
1761
1762    ta = tensor_array_ops.TensorArray(
1763        dtypes.float32, size=n_time, element_shape=[n_dim])
1764    for t in range(n_time):
1765      ta = ta.write(index=t, value=x[t])
1766      y = ta.stack()
1767      # dy is outside of the gradients name scope; tf.gradients must
1768      # wrap it in the correct name scope.
1769      dx, = gradients_impl.gradients(ys=[y], xs=[x], grad_ys=[dy])
1770      with self.cached_session():
1771        vdx, vdy = self.evaluate([dx, dy])
1772      self.assertAllClose(vdx, vdy)
1773
1774  @test_util.deprecated_graph_mode_only
1775  def testSkipEagerTensorArrayInt64GPU(self):
1776    if not test.is_gpu_available():
1777      return
1778    with self.session(force_gpu=True) as sess:
1779      value = array_ops.placeholder(dtypes.int64)
1780      ta = tensor_array_ops.TensorArray(dtype=dtypes.int64, size=2)
1781      ta = ta.scatter([0, 1], value)
1782      r0 = ta.read(0)
1783      r1 = ta.read(1)
1784      v0, v1 = sess.run([r0, r1], feed_dict={value: [-3, 100]})
1785      self.assertAllEqual(v0, -3)
1786      self.assertAllEqual(v1, 100)
1787
1788  def testInferShapeFalseValid(self):
1789    ta = tensor_array_ops.TensorArray(
1790        dtypes.float32, size=3, infer_shape=False, element_shape=[None, 10, 20])
1791    ta = ta.write(0, array_ops.ones([50, 10, 20]))
1792    ta = ta.write(1, array_ops.ones([50, 10, 20]))
1793    ta = ta.write(2, array_ops.ones([1, 10, 20]))
1794    ta = ta.concat()
1795
1796    correct = np.ones([101, 10, 20])
1797
1798    self.assertAllEqual(ta, correct)
1799
1800  def testInferShapeFalseInvalid(self):
1801    ta = tensor_array_ops.TensorArray(
1802        dtypes.float32, size=2, infer_shape=False, element_shape=[None, 10, 20])
1803    ta = ta.write(0, array_ops.ones([50, 10, 20]))
1804
1805    with self.assertRaises(ValueError):
1806      ta = ta.write(1, array_ops.ones([1, 20, 20]))
1807
1808  def testInferShapeTrue(self):
1809    ta = tensor_array_ops.TensorArray(
1810        dtypes.float32, size=3, infer_shape=True, element_shape=[None, 10, 20])
1811    self.assertAllEqual((None, 10, 20), ta.element_shape.as_list())
1812    ta = ta.write(0, array_ops.ones([50, 10, 20]))
1813    self.assertAllEqual((50, 10, 20), ta.element_shape.as_list())
1814    ta = ta.write(1, array_ops.ones([50, 10, 20]))
1815    with self.assertRaises(ValueError):
1816      ta = ta.write(
1817          2, array_ops.ones([1, 10, 20])
1818      )  # Inconsistent shapes: saw (1, 10, 20) but expected (50, 10, 20)
1819
1820  def testStackShapeOnEmpty(self):
1821    ta = tensor_array_ops.TensorArray(
1822        dtypes.float32, size=0, element_shape=(5, 10), dynamic_size=True)
1823    self.assertAllEqual([0, 5, 10], self.evaluate(ta.stack()).shape)
1824
1825  @test_util.run_deprecated_v1
1826  def testSkipEagerStackOnPartiallyDefinedShape(self):
1827    ta = tensor_array_ops.TensorArray(
1828        dtypes.float32, size=0, element_shape=(5, None), dynamic_size=True)
1829    self.assertEqual([None, 5, None], ta.stack().shape.as_list())
1830
1831  def testStackShapeOnStaticSize(self):
1832    ta = tensor_array_ops.TensorArray(dtypes.float32, size=42)
1833    ta = ta.write(0, [0])
1834    self.assertEqual([42, 1], ta.stack().shape.as_list())
1835
1836
1837class TensorArrayBenchmark(test.Benchmark):
1838
1839  def _tensorArrayWriteInWhile(self):
1840    size = 10000
1841    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size)
1842    (_, ta) = control_flow_ops.while_loop(
1843        lambda i, _: i < size,
1844        lambda i, ta: (i + 1, ta.write(i, 0.)), [0, ta],
1845        parallel_iterations=1)
1846    return ta.stack()
1847
1848  def _benchmarkWriteInWhile(self):
1849    ops.reset_default_graph()
1850    op = self._tensorArrayWriteInWhile()
1851    self.run_op_benchmark(session_lib.Session(), op)
1852
1853  def benchmarkWriteInWhile(self):
1854    self._benchmarkWriteInWhile()
1855
1856  @test_util.enable_control_flow_v2
1857  def benchmarkWriteInWhileWithControlFlowV2(self):
1858    self._benchmarkWriteInWhile()
1859
1860  def benchmarkWriteInDatasetMapFn(self):
1861    ds = dataset_ops.Dataset.from_tensors(array_ops.zeros([10])).repeat()
1862    ds = ds.map(lambda _: self._tensorArrayWriteInWhile())
1863    op = ds.make_one_shot_iterator().get_next()
1864    self.run_op_benchmark(session_lib.Session(), op)
1865
1866  def benchmarkWriteInDatasetParallelMapFn(self):
1867    ds = dataset_ops.Dataset.from_tensors(array_ops.zeros([10])).repeat()
1868    ds = ds.map(lambda _: self._tensorArrayWriteInWhile(), num_parallel_calls=2)
1869    op = ds.make_one_shot_iterator().get_next()
1870    self.run_op_benchmark(session_lib.Session(), op)
1871
1872
1873if __name__ == "__main__":
1874  test.main()
1875