xref: /aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/nn_ops/rnn_cell_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for RNN cells."""
16
17import itertools
18import os
19
20from absl.testing import parameterized
21import numpy as np
22
23from tensorflow.core.protobuf import config_pb2
24from tensorflow.python.checkpoint import checkpoint as trackable_utils
25from tensorflow.python.eager import context
26from tensorflow.python.eager import def_function
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import errors_impl
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import random_seed
32from tensorflow.python.framework import tensor_shape
33from tensorflow.python.framework import tensor_spec
34from tensorflow.python.framework import test_util
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import gen_rnn_ops
38from tensorflow.python.ops import gradients_impl
39from tensorflow.python.ops import init_ops
40from tensorflow.python.ops import math_ops
41from tensorflow.python.ops import rnn
42from tensorflow.python.ops import rnn_cell
43from tensorflow.python.ops import rnn_cell_impl
44from tensorflow.python.ops import state_ops
45from tensorflow.python.ops import tensor_array_ops
46from tensorflow.python.ops import variable_scope
47from tensorflow.python.ops import variables as variables_lib
48from tensorflow.python.platform import test
49from tensorflow.python.platform import tf_logging
50from tensorflow.python.saved_model import load
51from tensorflow.python.saved_model import save
52from tensorflow.python.trackable import autotrackable
53from tensorflow.python.util import nest
54
55
56class Plus1RNNCell(rnn_cell.RNNCell):
57  """RNN Cell generating (output, new_state) = (input + 1, state + 1)."""
58
59  @property
60  def output_size(self):
61    return 5
62
63  @property
64  def state_size(self):
65    return 5
66
67  def __call__(self, input_, state, scope=None):
68    return (input_ + 1, state + 1)
69
70
71class DummyMultiDimensionalLSTM(rnn_cell.RNNCell):
72  """LSTM Cell generating (output, new_state) = (input + 1, state + 1).
73
74  The input to this cell may have an arbitrary number of dimensions following
75  the leading 'Time' and 'Batch' dimensions.
76  """
77
78  def __init__(self, dims):
79    """Initialize the Multi-dimensional LSTM cell.
80
81    Args:
82      dims: tuple that contains the dimensions of the output of the cell,
83      without including 'Time' or 'Batch' dimensions.
84    """
85    if not isinstance(dims, tuple):
86      raise TypeError("The dimensions passed to DummyMultiDimensionalLSTM "
87                      "should be a tuple of ints.")
88    self._dims = dims
89    self._output_size = tensor_shape.TensorShape(self._dims)
90    self._state_size = (tensor_shape.TensorShape(self._dims),
91                        tensor_shape.TensorShape(self._dims))
92
93  @property
94  def output_size(self):
95    return self._output_size
96
97  @property
98  def state_size(self):
99    return self._state_size
100
101  def __call__(self, input_, state, scope=None):
102    h, c = state
103    return (input_ + 1, (h + 1, c + 1))
104
105
106class NestedRNNCell(rnn_cell.RNNCell):
107  """RNN Cell generating (output, new_state) = (input + 1, state + 1).
108
109  The input, output, and state of this cell are each a tuple of two tensors.
110  """
111
112  @property
113  def output_size(self):
114    return (5, 5)
115
116  @property
117  def state_size(self):
118    return (6, 6)
119
120  def __call__(self, input_, state, scope=None):
121    h, c = state
122    x, y = input_
123    return ((x + 1, y + 1), (h + 1, c + 1))
124
125
126class TestStateSaver(object):
127
128  def __init__(self, batch_size, state_size):
129    self._batch_size = batch_size
130    self._state_size = state_size
131    self.saved_state = {}
132
133  def state(self, name):
134
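    # Mimic a real state saver: hand back an all-zeros initial state of shape
    # (batch_size,) + state_size for the requested state name.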
135    if isinstance(self._state_size, dict):
136      state_size = self._state_size[name]
137    else:
138      state_size = self._state_size
139    if isinstance(state_size, int):
140      state_size = (state_size,)
141    elif isinstance(state_size, tuple):
142      pass
143    else:
144      raise TypeError("state_size should either be an int or a tuple")
145
146    return array_ops.zeros((self._batch_size,) + state_size)
147
148  def save_state(self, name, state):
149    self.saved_state[name] = state
150    return array_ops.identity(state)
151
152  @property
153  def batch_size(self):
154    return self._batch_size
155
156  @property
157  def state_size(self):
158    return self._state_size
159
160
161class TestStateSaverWithCounters(TestStateSaver):
162  """Class wrapper around TestStateSaver.
163
164  A dummy class used for testing static_state_saving_rnn. It helps verify that
165  the save_state and state functions are called the same number of times
166  whether we evaluate the RNN cell's output and state together or separately.
167  It inherits from TestStateSaver and adds counters for those function calls.
168  """
169
170  @test_util.run_v1_only("b/124229375")
171  def __init__(self, batch_size, state_size):
172    super(TestStateSaverWithCounters, self).__init__(batch_size, state_size)
173    self._num_state_calls = variables_lib.VariableV1(0)
174    self._num_save_state_calls = variables_lib.VariableV1(0)
175
176  def state(self, name):
177    with ops.control_dependencies(
178        [state_ops.assign_add(self._num_state_calls, 1)]):
179      return super(TestStateSaverWithCounters, self).state(name)
180
181  def save_state(self, name, state):
182    with ops.control_dependencies([state_ops.assign_add(
183        self._num_save_state_calls, 1)]):
184      return super(TestStateSaverWithCounters, self).save_state(name, state)
185
186  @property
187  def num_state_calls(self):
188    return self._num_state_calls
189
190  @property
191  def num_save_state_calls(self):
192    return self._num_save_state_calls
193
194
195class RNNTest(test.TestCase):
196
197  def setUp(self):
198    self._seed = 23489
199    np.random.seed(self._seed)
200
201  @test_util.run_v1_only("b/124229375")
202  def testInvalidSequenceLengthShape(self):
203    cell = Plus1RNNCell()
204    inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
205    with self.assertRaisesRegex(ValueError, "must be a vector"):
206      rnn.static_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=4)
207
208  @test_util.run_v1_only("b/124229375")
209  def testRNN(self):
210    cell = Plus1RNNCell()
211    batch_size = 2
212    input_size = 5
213    max_length = 8  # unrolled up to this length
214    inputs = max_length * [
215        array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
216    ]
217    outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
218    self.assertEqual(len(outputs), len(inputs))
219    for out, inp in zip(outputs, inputs):
220      self.assertEqual(out.get_shape(), inp.get_shape())
221      self.assertEqual(out.dtype, inp.dtype)
222
223    with self.session() as sess:
224      input_value = np.random.randn(batch_size, input_size)
225      values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
226
227      # Outputs
228      for v in values[:-1]:
229        self.assertAllClose(v, input_value + 1.0)
230
231      # Final state
232      self.assertAllClose(values[-1],
233                          max_length * np.ones(
234                              (batch_size, input_size), dtype=np.float32))
235
236  @test_util.run_v1_only("b/124229375")
237  def testDropout(self):
238    cell = Plus1RNNCell()
239    full_dropout_cell = rnn_cell.DropoutWrapper(
240        cell, input_keep_prob=1e-6, seed=0)
241    self.assertIn("cell", full_dropout_cell._trackable_children())
242    self.assertIs(full_dropout_cell._trackable_children()["cell"], cell)
243    batch_size = 2
244    input_size = 5
245    max_length = 8
246    inputs = max_length * [
247        array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
248    ]
249    with variable_scope.variable_scope("share_scope"):
250      outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
251    with variable_scope.variable_scope("drop_scope"):
252      dropped_outputs, _ = rnn.static_rnn(
253          full_dropout_cell, inputs, dtype=dtypes.float32)
254    self.assertEqual(len(outputs), len(inputs))
255    for out, inp in zip(outputs, inputs):
256      self.assertEqual(out.get_shape().as_list(), inp.get_shape().as_list())
257      self.assertEqual(out.dtype, inp.dtype)
258
259    with self.session() as sess:
260      input_value = np.random.randn(batch_size, input_size)
261      values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
262      full_dropout_values = sess.run(
263          dropped_outputs, feed_dict={
264              inputs[0]: input_value
265          })
266
267      for v in values[:-1]:
268        self.assertAllClose(v, input_value + 1.0)
269      for d_v in full_dropout_values[:-1]:  # Add 1.0 to dropped_out (all zeros)
270        self.assertAllClose(d_v, np.ones_like(input_value))
271
272  @test_util.run_v1_only("b/124229375")
273  def testDynamicCalculation(self):
274    cell = Plus1RNNCell()
275    sequence_length = array_ops.placeholder(dtypes.int64)
276    batch_size = 2
277    input_size = 5
278    max_length = 8
279    inputs = max_length * [
280        array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
281    ]
282    with variable_scope.variable_scope("drop_scope"):
283      dynamic_outputs, dynamic_state = rnn.static_rnn(
284          cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)
285    self.assertEqual(len(dynamic_outputs), len(inputs))
286
287    with self.session() as sess:
288      input_value = np.random.randn(batch_size, input_size)
289      dynamic_values = sess.run(
290          dynamic_outputs,
291          feed_dict={
292              inputs[0]: input_value,
293              sequence_length: [2, 3]
294          })
295      dynamic_state_value = sess.run(
296          [dynamic_state],
297          feed_dict={
298              inputs[0]: input_value,
299              sequence_length: [2, 3]
300          })
301
302      # outputs are fully calculated for t = 0, 1
303      for v in dynamic_values[:2]:
304        self.assertAllClose(v, input_value + 1.0)
305
306      # outputs at t = 2 are zero for entry 0, calculated for entry 1
307      self.assertAllClose(dynamic_values[2],
308                          np.vstack((np.zeros((input_size)),
309                                     1.0 + input_value[1, :])))
310
311      # outputs at t = 3+ are zero
312      for v in dynamic_values[3:]:
313        self.assertAllEqual(v, np.zeros_like(input_value))
314
315      # the final states are:
316      #  entry 0: the values from the calculation at t=1
317      #  entry 1: the values from the calculation at t=2
318      self.assertAllEqual(dynamic_state_value[0],
319                          np.vstack((1.0 * (1 + 1) * np.ones((input_size)),
320                                     1.0 * (2 + 1) * np.ones((input_size)))))
321
322  def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
323    with self.session(graph=ops.Graph()):
324      if use_outer_scope:
325        with variable_scope.variable_scope(prefix) as scope:
326          factory(scope)
327      else:
328        factory(prefix)
329
330      # Check that all the variable names start
331      # with the proper scope.
332      variables_lib.global_variables_initializer()
333      all_vars = variables_lib.global_variables()
334      prefix = prefix or "rnn"
335      scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")]
336      tf_logging.info("RNN with scope: %s (%s)" %
337                      (prefix, "scope" if use_outer_scope else "str"))
338      for v in scope_vars:
339        tf_logging.info(v.name)
340      self.assertEqual(len(scope_vars), len(all_vars))
341
342  @test_util.run_v1_only("b/124229375")
343  def testScope(self):
344
345    def factory(scope):
346      cell = Plus1RNNCell()
347      batch_size = 2
348      input_size = 5
349      max_length = 8  # unrolled up to this length
350      inputs = max_length * [
351          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
352      ]
353      return rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope=scope)
354
355    self._testScope(factory, use_outer_scope=True)
356    self._testScope(factory, use_outer_scope=False)
357    self._testScope(factory, prefix=None, use_outer_scope=False)
358
359
360class LSTMTest(test.TestCase):
361
362  def setUp(self):
363    self._seed = 23489
364    np.random.seed(self._seed)
365
366  def testDType(self):
367    # Test case for GitHub issue 16228
368    # Not passing dtype in constructor results in default float32
369    lstm = rnn_cell.LSTMCell(10)
370    input_tensor = array_ops.ones([10, 50])
371    lstm.build(input_tensor.get_shape())
372    self.assertEqual(lstm._bias.dtype.base_dtype, dtypes.float32)
373
374    # Explicitly pass dtype in constructor
375    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
376      lstm = rnn_cell.LSTMCell(10, dtype=dtype)
377      input_tensor = array_ops.ones([10, 50])
378      lstm.build(input_tensor.get_shape())
379      self.assertEqual(lstm._bias.dtype.base_dtype, dtype)
380
381  @test_util.run_v1_only("b/124229375")
382  def testNoProjNoSharding(self):
383    num_units = 3
384    input_size = 5
385    batch_size = 2
386    max_length = 8
387    with self.session(graph=ops.Graph()) as sess:
388      initializer = init_ops.random_uniform_initializer(
389          -0.01, 0.01, seed=self._seed)
390      cell = rnn_cell.LSTMCell(
391          num_units, initializer=initializer, state_is_tuple=False)
392      inputs = max_length * [
393          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
394      ]
395      outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
396      self.assertEqual(len(outputs), len(inputs))
397      for out in outputs:
398        self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
399
400      variables_lib.global_variables_initializer().run()
401      input_value = np.random.randn(batch_size, input_size)
402      sess.run(outputs, feed_dict={inputs[0]: input_value})
403
404  @test_util.run_v1_only("b/124229375")
405  def testCellClipping(self):
406    num_units = 3
407    input_size = 5
408    batch_size = 2
409    max_length = 8
410    with self.session(graph=ops.Graph()) as sess:
411      initializer = init_ops.random_uniform_initializer(
412          -0.01, 0.01, seed=self._seed)
413      cell = rnn_cell.LSTMCell(
414          num_units,
415          use_peepholes=True,
416          cell_clip=0.0,
417          initializer=initializer,
418          state_is_tuple=False)
419      inputs = max_length * [
420          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
421      ]
422      outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
423      self.assertEqual(len(outputs), len(inputs))
424      for out in outputs:
425        self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
426
427      variables_lib.global_variables_initializer().run()
428      input_value = np.random.randn(batch_size, input_size)
429      values = sess.run(outputs, feed_dict={inputs[0]: input_value})
430
431    for value in values:
432      # If cell state c is clipped to 0, tanh(c) = 0, so output m == 0.
433      self.assertAllEqual(value, np.zeros((batch_size, num_units)))
434
435  @test_util.run_v1_only("b/124229375")
436  def testNoProjNoShardingSimpleStateSaver(self):
437    num_units = 3
438    input_size = 5
439    batch_size = 2
440    max_length = 8
441    with self.session(graph=ops.Graph()) as sess:
442      initializer = init_ops.random_uniform_initializer(
443          -0.01, 0.01, seed=self._seed)
444      state_saver = TestStateSaver(batch_size, 2 * num_units)
445      cell = rnn_cell.LSTMCell(
446          num_units,
447          use_peepholes=False,
448          initializer=initializer,
449          state_is_tuple=False)
450      inputs = max_length * [
451          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
452      ]
453      with variable_scope.variable_scope("share_scope"):
454        outputs, state = rnn.static_state_saving_rnn(
455            cell, inputs, state_saver=state_saver, state_name="save_lstm")
456      self.assertEqual(len(outputs), len(inputs))
457      for out in outputs:
458        self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
459
460      variables_lib.global_variables_initializer().run()
461      input_value = np.random.randn(batch_size, input_size)
462      (last_state_value, saved_state_value) = sess.run(
463          [state, state_saver.saved_state["save_lstm"]],
464          feed_dict={
465              inputs[0]: input_value
466          })
467      self.assertAllEqual(last_state_value, saved_state_value)
468
469  @test_util.run_v1_only("b/124229375")
470  def testNoProjNoShardingTupleStateSaver(self):
471    num_units = 3
472    input_size = 5
473    batch_size = 2
474    max_length = 8
475    with self.session(graph=ops.Graph()) as sess:
476      initializer = init_ops.random_uniform_initializer(
477          -0.01, 0.01, seed=self._seed)
478      state_saver = TestStateSaver(batch_size, num_units)
479      cell = rnn_cell.LSTMCell(
480          num_units,
481          use_peepholes=False,
482          initializer=initializer,
483          state_is_tuple=True)
484      inputs = max_length * [
485          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
486      ]
487      with variable_scope.variable_scope("share_scope"):
488        outputs, state = rnn.static_state_saving_rnn(
489            cell, inputs, state_saver=state_saver, state_name=("c", "m"))
490      self.assertEqual(len(outputs), len(inputs))
491      for out in outputs:
492        self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])
493
494      variables_lib.global_variables_initializer().run()
495      input_value = np.random.randn(batch_size, input_size)
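      # state is an LSTMStateTuple (c, m); together with the two saved tensors
      # this forms a 4-element fetch list that is compared pairwise below.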
496      last_and_saved_states = sess.run(
497          state + (state_saver.saved_state["c"], state_saver.saved_state["m"]),
498          feed_dict={
499              inputs[0]: input_value
500          })
501      self.assertEqual(4, len(last_and_saved_states))
502      self.assertAllEqual(last_and_saved_states[:2], last_and_saved_states[2:])
503
504  @test_util.run_v1_only("b/124229375")
505  def testNoProjNoShardingNestedTupleStateSaver(self):
506    num_units = 3
507    input_size = 5
508    batch_size = 2
509    max_length = 8
510    with self.session(graph=ops.Graph()) as sess:
511      initializer = init_ops.random_uniform_initializer(
512          -0.01, 0.01, seed=self._seed)
513      state_saver = TestStateSaver(
514          batch_size, {
515              "c0": num_units,
516              "m0": num_units,
517              "c1": num_units + 1,
518              "m1": num_units + 1,
519              "c2": num_units + 2,
520              "m2": num_units + 2,
521              "c3": num_units + 3,
522              "m3": num_units + 3
523          })
524
525      def _cell(i):
526        return rnn_cell.LSTMCell(
527            num_units + i,
528            use_peepholes=False,
529            initializer=initializer,
530            state_is_tuple=True)
531
532      # This creates a state tuple which has 4 sub-tuples of length 2 each.
533      cell = rnn_cell.MultiRNNCell(
534          [_cell(i) for i in range(4)], state_is_tuple=True)
535
536      self.assertEqual(len(cell.state_size), 4)
537      for i in range(4):
538        self.assertEqual(len(cell.state_size[i]), 2)
539
540      inputs = max_length * [
541          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
542      ]
543
544      state_names = (("c0", "m0"), ("c1", "m1"), ("c2", "m2"), ("c3", "m3"))
545      with variable_scope.variable_scope("share_scope"):
546        outputs, state = rnn.static_state_saving_rnn(
547            cell, inputs, state_saver=state_saver, state_name=state_names)
548      self.assertEqual(len(outputs), len(inputs))
549
550      # Final output comes from _cell(3) which has state size num_units + 3
551      for out in outputs:
552        self.assertEqual(out.get_shape().as_list(), [batch_size, num_units + 3])
553
554      variables_lib.global_variables_initializer().run()
555      input_value = np.random.randn(batch_size, input_size)
556      last_states = sess.run(
557          list(nest.flatten(state)), feed_dict={
558              inputs[0]: input_value
559          })
560      saved_states = sess.run(
561          list(state_saver.saved_state.values()),
562          feed_dict={
563              inputs[0]: input_value
564          })
565      self.assertEqual(8, len(last_states))
566      self.assertEqual(8, len(saved_states))
567      flat_state_names = nest.flatten(state_names)
568      named_saved_states = dict(
569          zip(state_saver.saved_state.keys(), saved_states))
570
571      for i in range(8):
572        self.assertAllEqual(last_states[i],
573                            named_saved_states[flat_state_names[i]])
574
575  @test_util.run_v1_only("b/124229375")
576  def testProjNoSharding(self):
577    num_units = 3
578    input_size = 5
579    batch_size = 2
580    num_proj = 4
581    max_length = 8
582    with self.session(graph=ops.Graph()) as sess:
583      initializer = init_ops.random_uniform_initializer(
584          -0.01, 0.01, seed=self._seed)
585      inputs = max_length * [
586          array_ops.placeholder(dtypes.float32, shape=(None, input_size))
587      ]
588      cell = rnn_cell.LSTMCell(
589          num_units,
590          use_peepholes=True,
591          num_proj=num_proj,
592          initializer=initializer,
593          state_is_tuple=False)
594      outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
595      self.assertEqual(len(outputs), len(inputs))
596
597      variables_lib.global_variables_initializer().run()
598      input_value = np.random.randn(batch_size, input_size)
599      sess.run(outputs, feed_dict={inputs[0]: input_value})
600
601  def _testStateTupleWithProjAndSequenceLength(self):
602    num_units = 3
603    input_size = 5
604    batch_size = 2
605    num_proj = 4
606    max_length = 8
607    sequence_length = [4, 6]
608    with self.session(graph=ops.Graph()) as sess:
609      initializer = init_ops.random_uniform_initializer(
610          -0.01, 0.01, seed=self._seed)
611      inputs = max_length * [
612          array_ops.placeholder(dtypes.float32, shape=(None, input_size))
613      ]
614      cell_notuple = rnn_cell.LSTMCell(
615          num_units,
616          use_peepholes=True,
617          num_proj=num_proj,
618          initializer=initializer,
619          state_is_tuple=False)
620      cell_tuple = rnn_cell.LSTMCell(
621          num_units,
622          use_peepholes=True,
623          num_proj=num_proj,
624          initializer=initializer,
625          state_is_tuple=True)
626      with variable_scope.variable_scope("root") as scope:
627        outputs_notuple, state_notuple = rnn.static_rnn(
628            cell_notuple,
629            inputs,
630            dtype=dtypes.float32,
631            sequence_length=sequence_length,
632            scope=scope)
633        scope.reuse_variables()
634        # TODO(ebrevdo): For this test, we ensure values are identical and
635        # therefore the weights here are tied.  In the future, we may consider
636        # making the state_is_tuple property mutable so we can avoid
637        # having to do this - especially if users ever need to reuse
638        # the parameters from different RNNCell instances.  Right now,
639        # this seems an unrealistic use case except for testing.
640        cell_tuple._scope = cell_notuple._scope  # pylint: disable=protected-access
641        outputs_tuple, state_tuple = rnn.static_rnn(
642            cell_tuple,
643            inputs,
644            dtype=dtypes.float32,
645            sequence_length=sequence_length,
646            scope=scope)
647      self.assertEqual(len(outputs_notuple), len(inputs))
648      self.assertEqual(len(outputs_tuple), len(inputs))
649      self.assertTrue(isinstance(state_tuple, tuple))
650      self.assertTrue(isinstance(state_notuple, ops.Tensor))
651
652      variables_lib.global_variables_initializer().run()
653      input_value = np.random.randn(batch_size, input_size)
654      outputs_notuple_v = sess.run(
655          outputs_notuple, feed_dict={
656              inputs[0]: input_value
657          })
658      outputs_tuple_v = sess.run(
659          outputs_tuple, feed_dict={
660              inputs[0]: input_value
661          })
662      self.assertAllEqual(outputs_notuple_v, outputs_tuple_v)
663
664      (state_notuple_v,) = sess.run(
665          (state_notuple,), feed_dict={
666              inputs[0]: input_value
667          })
668      state_tuple_v = sess.run(state_tuple, feed_dict={inputs[0]: input_value})
669      self.assertAllEqual(state_notuple_v, np.hstack(state_tuple_v))
670
671  @test_util.run_v1_only("b/124229375")
672  def testProjSharding(self):
673    num_units = 3
674    input_size = 5
675    batch_size = 2
676    num_proj = 4
677    num_proj_shards = 3
678    num_unit_shards = 2
679    max_length = 8
680    with self.session(graph=ops.Graph()) as sess:
681      initializer = init_ops.random_uniform_initializer(
682          -0.01, 0.01, seed=self._seed)
683
684      inputs = max_length * [
685          array_ops.placeholder(dtypes.float32, shape=(None, input_size))
686      ]
687
688      cell = rnn_cell.LSTMCell(
689          num_units,
690          use_peepholes=True,
691          num_proj=num_proj,
692          num_unit_shards=num_unit_shards,
693          num_proj_shards=num_proj_shards,
694          initializer=initializer,
695          state_is_tuple=False)
696
697      outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
698
699      self.assertEqual(len(outputs), len(inputs))
700
701      variables_lib.global_variables_initializer().run()
702      input_value = np.random.randn(batch_size, input_size)
703      sess.run(outputs, feed_dict={inputs[0]: input_value})
704
705  @test_util.run_v1_only("b/124229375")
706  def testDoubleInput(self):
707    num_units = 3
708    input_size = 5
709    batch_size = 2
710    num_proj = 4
711    num_proj_shards = 3
712    num_unit_shards = 2
713    max_length = 8
714    with self.session(graph=ops.Graph()) as sess:
715      initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed)
716      inputs = max_length * [
717          array_ops.placeholder(dtypes.float64, shape=(None, input_size))
718      ]
719
720      cell = rnn_cell.LSTMCell(
721          num_units,
722          use_peepholes=True,
723          num_proj=num_proj,
724          num_unit_shards=num_unit_shards,
725          num_proj_shards=num_proj_shards,
726          initializer=initializer,
727          state_is_tuple=False)
728
729      outputs, _ = rnn.static_rnn(
730          cell,
731          inputs,
732          initial_state=cell.zero_state(batch_size, dtypes.float64))
733
734      self.assertEqual(len(outputs), len(inputs))
735
736      variables_lib.global_variables_initializer().run()
737      input_value = np.asarray(
738          np.random.randn(batch_size, input_size), dtype=np.float64)
739      values = sess.run(outputs, feed_dict={inputs[0]: input_value})
740      self.assertEqual(values[0].dtype, input_value.dtype)
741
742  @test_util.run_v1_only("b/124229375")
743  def testShardNoShardEquivalentOutput(self):
744    num_units = 3
745    input_size = 5
746    batch_size = 2
747    num_proj = 4
748    num_proj_shards = 3
749    num_unit_shards = 2
750    max_length = 8
751    with self.session(graph=ops.Graph()) as sess:
752      inputs = max_length * [
753          array_ops.placeholder(dtypes.float32, shape=(None, input_size))
754      ]
755      initializer = init_ops.constant_initializer(0.001)
756
757      cell_noshard = rnn_cell.LSTMCell(
758          num_units,
759          num_proj=num_proj,
760          use_peepholes=True,
761          initializer=initializer,
762          num_unit_shards=num_unit_shards,
763          num_proj_shards=num_proj_shards,
764          state_is_tuple=False)
765
766      cell_shard = rnn_cell.LSTMCell(
767          num_units,
768          use_peepholes=True,
769          initializer=initializer,
770          num_proj=num_proj,
771          state_is_tuple=False)
772
773      with variable_scope.variable_scope("noshard_scope"):
774        outputs_noshard, state_noshard = rnn.static_rnn(
775            cell_noshard, inputs, dtype=dtypes.float32)
776      with variable_scope.variable_scope("shard_scope"):
777        outputs_shard, state_shard = rnn.static_rnn(
778            cell_shard, inputs, dtype=dtypes.float32)
779
780      self.assertEqual(len(outputs_noshard), len(inputs))
781      self.assertEqual(len(outputs_noshard), len(outputs_shard))
782
783      variables_lib.global_variables_initializer().run()
784      input_value = np.random.randn(batch_size, input_size)
785      feeds = dict((x, input_value) for x in inputs)
786      values_noshard = sess.run(outputs_noshard, feed_dict=feeds)
787      values_shard = sess.run(outputs_shard, feed_dict=feeds)
788      state_values_noshard = sess.run([state_noshard], feed_dict=feeds)
789      state_values_shard = sess.run([state_shard], feed_dict=feeds)
790      self.assertEqual(len(values_noshard), len(values_shard))
791      self.assertEqual(len(state_values_noshard), len(state_values_shard))
792      for (v_noshard, v_shard) in zip(values_noshard, values_shard):
793        self.assertAllClose(v_noshard, v_shard, atol=1e-3)
794      for (s_noshard, s_shard) in zip(state_values_noshard, state_values_shard):
795        self.assertAllClose(s_noshard, s_shard, atol=1e-3)
796
797  @test_util.run_v1_only("b/124229375")
798  def testDoubleInputWithDropoutAndDynamicCalculation(self):
799    """Smoke test for using LSTM with doubles, dropout, dynamic calculation."""
800
801    num_units = 3
802    input_size = 5
803    batch_size = 2
804    num_proj = 4
805    num_proj_shards = 3
806    num_unit_shards = 2
807    max_length = 8
808    with self.session(graph=ops.Graph()) as sess:
809      sequence_length = array_ops.placeholder(dtypes.int64)
810      initializer = init_ops.random_uniform_initializer(
811          -0.01, 0.01, seed=self._seed)
812      inputs = max_length * [
813          array_ops.placeholder(dtypes.float64, shape=(None, input_size))
814      ]
815
816      cell = rnn_cell.LSTMCell(
817          num_units,
818          use_peepholes=True,
819          num_proj=num_proj,
820          num_unit_shards=num_unit_shards,
821          num_proj_shards=num_proj_shards,
822          initializer=initializer,
823          state_is_tuple=False)
824      dropout_cell = rnn_cell.DropoutWrapper(cell, 0.5, seed=0)
825
826      outputs, state = rnn.static_rnn(
827          dropout_cell,
828          inputs,
829          sequence_length=sequence_length,
830          initial_state=cell.zero_state(batch_size, dtypes.float64))
831
832      self.assertEqual(len(outputs), len(inputs))
833
834      variables_lib.global_variables_initializer().run(feed_dict={
835          sequence_length: [2, 3]
836      })
837      input_value = np.asarray(
838          np.random.randn(batch_size, input_size), dtype=np.float64)
839      values = sess.run(
840          outputs, feed_dict={
841              inputs[0]: input_value,
842              sequence_length: [2, 3]
843          })
844      state_value = sess.run(
845          [state], feed_dict={
846              inputs[0]: input_value,
847              sequence_length: [2, 3]
848          })
849      self.assertEqual(values[0].dtype, input_value.dtype)
850      self.assertEqual(state_value[0].dtype, input_value.dtype)
851
852  @test_util.run_v1_only("b/124229375")
853  def testSharingWeightsWithReuse(self):
854    num_units = 3
855    input_size = 5
856    batch_size = 2
857    num_proj = 4
858    max_length = 8
859    with self.session(graph=ops.Graph()) as sess:
860      initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed)
861      initializer_d = init_ops.random_uniform_initializer(
862          -1, 1, seed=self._seed + 1)
863      inputs = max_length * [
864          array_ops.placeholder(dtypes.float32, shape=(None, input_size))
865      ]
866      cell = rnn_cell.LSTMCell(
867          num_units,
868          use_peepholes=True,
869          num_proj=num_proj,
870          initializer=initializer,
871          state_is_tuple=False)
872      cell_d = rnn_cell.LSTMCell(
873          num_units,
874          use_peepholes=True,
875          num_proj=num_proj,
876          initializer=initializer_d,
877          state_is_tuple=False)
878
879      with variable_scope.variable_scope("share_scope"):
880        outputs0, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
881      with variable_scope.variable_scope("share_scope", reuse=True):
882        outputs1, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
883      with variable_scope.variable_scope("diff_scope"):
884        outputs2, _ = rnn.static_rnn(cell_d, inputs, dtype=dtypes.float32)
885
886      variables_lib.global_variables_initializer().run()
887      input_value = np.random.randn(batch_size, input_size)
888      output_values = sess.run(
889          outputs0 + outputs1 + outputs2, feed_dict={
890              inputs[0]: input_value
891          })
892      outputs0_values = output_values[:max_length]
893      outputs1_values = output_values[max_length:2 * max_length]
894      outputs2_values = output_values[2 * max_length:]
895      self.assertEqual(len(outputs0_values), len(outputs1_values))
896      self.assertEqual(len(outputs0_values), len(outputs2_values))
897      for o1, o2, o3 in zip(outputs0_values, outputs1_values, outputs2_values):
898        # Same weights used by both RNNs so outputs should be the same.
899        self.assertAllEqual(o1, o2)
900        # Different weights used so outputs should be different.
901        self.assertTrue(np.linalg.norm(o1 - o3) > 1e-6)
902
903  @test_util.run_v1_only("b/124229375")
904  def testSharingWeightsWithDifferentNamescope(self):
905    num_units = 3
906    input_size = 5
907    batch_size = 2
908    num_proj = 4
909    max_length = 8
910    with self.session(graph=ops.Graph()) as sess:
911      initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed)
912      inputs = max_length * [
913          array_ops.placeholder(dtypes.float32, shape=(None, input_size))
914      ]
915      cell = rnn_cell.LSTMCell(
916          num_units,
917          use_peepholes=True,
918          num_proj=num_proj,
919          initializer=initializer,
920          state_is_tuple=False)
921
922      with ops.name_scope("scope0"):
923        with variable_scope.variable_scope("share_scope"):
924          outputs0, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
925      with ops.name_scope("scope1"):
926        with variable_scope.variable_scope("share_scope", reuse=True):
927          outputs1, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
928
929      variables_lib.global_variables_initializer().run()
930      input_value = np.random.randn(batch_size, input_size)
931      output_values = sess.run(
932          outputs0 + outputs1, feed_dict={
933              inputs[0]: input_value
934          })
935      outputs0_values = output_values[:max_length]
936      outputs1_values = output_values[max_length:]
937      self.assertEqual(len(outputs0_values), len(outputs1_values))
938      for out0, out1 in zip(outputs0_values, outputs1_values):
939        self.assertAllEqual(out0, out1)
940
941  @test_util.run_v1_only("b/124229375")
942  def testDynamicRNNAllowsUnknownTimeDimension(self):
943    inputs = array_ops.placeholder(dtypes.float32, shape=[1, None, 20])
944    cell = rnn_cell.GRUCell(30)
945    # Smoke test; this should not raise an error.
946    rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
947
948  @test_util.run_in_graph_and_eager_modes
949  def testDynamicRNNWithTupleStates(self):
950    num_units = 3
951    input_size = 5
952    batch_size = 2
953    num_proj = 4
954    max_length = 8
955    sequence_length = [4, 6]
956    in_graph_mode = not context.executing_eagerly()
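    # In graph mode the inputs are fed via placeholders; in eager mode they
    # are materialized as constants, since there is no feed mechanism.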
957    with self.session(graph=ops.Graph()) as sess:
958      initializer = init_ops.random_uniform_initializer(
959          -0.01, 0.01, seed=self._seed)
960      if in_graph_mode:
961        inputs = max_length * [
962            array_ops.placeholder(dtypes.float32, shape=(None, input_size))
963        ]
964      else:
965        inputs = max_length * [
966            constant_op.constant(
967                np.random.randn(batch_size, input_size).astype(np.float32))
968        ]
969      inputs_c = array_ops.stack(inputs)
970      cell = rnn_cell.LSTMCell(
971          num_units,
972          use_peepholes=True,
973          num_proj=num_proj,
974          initializer=initializer,
975          state_is_tuple=True)
976      with variable_scope.variable_scope("root") as scope:
977        outputs_static, state_static = rnn.static_rnn(
978            cell,
979            inputs,
980            dtype=dtypes.float32,
981            sequence_length=sequence_length,
982            scope=scope)
983        scope.reuse_variables()
984        outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
985            cell,
986            inputs_c,
987            dtype=dtypes.float32,
988            time_major=True,
989            sequence_length=sequence_length,
990            scope=scope)
991      self.assertTrue(isinstance(state_static, rnn_cell.LSTMStateTuple))
992      self.assertTrue(isinstance(state_dynamic, rnn_cell.LSTMStateTuple))
993      self.assertIs(state_static[0], state_static.c)
994      self.assertIs(state_static[1], state_static.h)
995      self.assertIs(state_dynamic[0], state_dynamic.c)
996      self.assertIs(state_dynamic[1], state_dynamic.h)
997
998      if in_graph_mode:
999        variables_lib.global_variables_initializer().run()
1000        input_value = np.random.randn(batch_size, input_size)
1001        outputs_static = sess.run(
1002            outputs_static, feed_dict={
1003                inputs[0]: input_value
1004            })
1005        outputs_dynamic = sess.run(
1006            outputs_dynamic, feed_dict={
1007                inputs[0]: input_value
1008            })
1009        state_static = sess.run(
1010            state_static, feed_dict={
1011                inputs[0]: input_value
1012            })
1013        state_dynamic = sess.run(
1014            state_dynamic, feed_dict={
1015                inputs[0]: input_value
1016            })
1017
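      # Bitwise equality is not guaranteed under XLA, so fall back to a
      # tolerance-based comparison there.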
1018      comparison_fn = self.assertAllEqual
1019      if test_util.is_xla_enabled():
1020        comparison_fn = self.assertAllClose
1021      if in_graph_mode:
1022        comparison_fn(outputs_static, outputs_dynamic)
1023      else:
1024        self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic)
1025      comparison_fn(np.hstack(state_static), np.hstack(state_dynamic))
1026
1027  @test_util.run_in_graph_and_eager_modes
1028  def testDynamicRNNWithNestedTupleStates(self):
1029    num_units = 3
1030    input_size = 5
1031    batch_size = 2
1032    num_proj = 4
1033    max_length = 8
1034    sequence_length = [4, 6]
1035    in_graph_mode = not context.executing_eagerly()
1036    with self.session(graph=ops.Graph()) as sess:
1037      initializer = init_ops.random_uniform_initializer(
1038          -0.01, 0.01, seed=self._seed)
1039      if in_graph_mode:
1040        inputs = max_length * [
1041            array_ops.placeholder(dtypes.float32, shape=(None, input_size))
1042        ]
1043      else:
1044        inputs = max_length * [
1045            constant_op.constant(
1046                np.random.randn(batch_size, input_size).astype(np.float32))
1047        ]
1048      inputs_c = array_ops.stack(inputs)
1049
1050      def _cell(i):
1051        return rnn_cell.LSTMCell(
1052            num_units + i,
1053            use_peepholes=True,
1054            num_proj=num_proj + i,
1055            initializer=initializer,
1056            state_is_tuple=True)
1057
1058      # This creates a state tuple which has 4 sub-tuples of length 2 each.
1059      cell = rnn_cell.MultiRNNCell(
1060          [_cell(i) for i in range(4)], state_is_tuple=True)
1061
1062      self.assertEqual(len(cell.state_size), 4)
1063      for i in range(4):
1064        self.assertEqual(len(cell.state_size[i]), 2)
1065
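      # zero_state should mirror the nested structure and per-layer sizes
      # reported by cell.state_size.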
1066      test_zero = cell.zero_state(1, dtypes.float32)
1067      self.assertEqual(len(test_zero), 4)
1068      for i in range(4):
1069        self.assertEqual(test_zero[i][0].get_shape()[1], cell.state_size[i][0])
1070        self.assertEqual(test_zero[i][1].get_shape()[1], cell.state_size[i][1])
1071
1072      with variable_scope.variable_scope("root") as scope:
1073        outputs_static, state_static = rnn.static_rnn(
1074            cell,
1075            inputs,
1076            dtype=dtypes.float32,
1077            sequence_length=sequence_length,
1078            scope=scope)
1079        scope.reuse_variables()
1080        outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
1081            cell,
1082            inputs_c,
1083            dtype=dtypes.float32,
1084            time_major=True,
1085            sequence_length=sequence_length,
1086            scope=scope)
1087
1088      if in_graph_mode:
1089        input_value = np.random.randn(batch_size, input_size)
1090        variables_lib.global_variables_initializer().run()
1091        outputs_static = sess.run(
1092            outputs_static, feed_dict={
1093                inputs[0]: input_value
1094            })
1095        outputs_dynamic = sess.run(
1096            outputs_dynamic, feed_dict={
1097                inputs[0]: input_value
1098            })
1099        state_static = sess.run(
1100            nest.flatten(state_static), feed_dict={
1101                inputs[0]: input_value
1102            })
1103        state_dynamic = sess.run(
1104            nest.flatten(state_dynamic), feed_dict={
1105                inputs[0]: input_value
1106            })
1107
1108      comparison_fn = self.assertAllEqual
1109      if test_util.is_xla_enabled():
1110        comparison_fn = self.assertAllClose
1111      if in_graph_mode:
1112        comparison_fn(outputs_static, outputs_dynamic)
1113      else:
1114        self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic)
1115        state_static = nest.flatten(state_static)
1116        state_dynamic = nest.flatten(state_dynamic)
1117      comparison_fn(np.hstack(state_static), np.hstack(state_dynamic))
1118
1119  def _testDynamicEquivalentToStaticRNN(self, use_sequence_length):
1120    time_steps = 8
1121    num_units = 3
1122    num_proj = 4
1123    input_size = 5
1124    batch_size = 2
1125
1126    input_values = np.random.randn(time_steps, batch_size, input_size).astype(
1127        np.float32)
1128
1129    if use_sequence_length:
1130      sequence_length = np.random.randint(0, time_steps, size=batch_size)
1131    else:
1132      sequence_length = None
1133
1134    in_graph_mode = not context.executing_eagerly()
1135
1136    # TODO(b/68017812): Eager ignores operation seeds, so we need to create a
1137    # single cell and reuse it across the static and dynamic RNNs. Remove this
1138    # special case once that bug is fixed.
1139    if not in_graph_mode:
1140      initializer = init_ops.random_uniform_initializer(
1141          -0.01, 0.01, seed=self._seed)
1142      cell = rnn_cell.LSTMCell(
1143          num_units,
1144          use_peepholes=True,
1145          initializer=initializer,
1146          num_proj=num_proj,
1147          state_is_tuple=False)
1148
1149    ########## Step 1: Run static graph and generate readouts
1150    with self.session(graph=ops.Graph()) as sess:
1151      if in_graph_mode:
1152        concat_inputs = array_ops.placeholder(
1153            dtypes.float32, shape=(time_steps, batch_size, input_size))
1154      else:
1155        concat_inputs = constant_op.constant(input_values)
1156      inputs = array_ops.unstack(concat_inputs)
1157      initializer = init_ops.random_uniform_initializer(
1158          -0.01, 0.01, seed=self._seed)
1159
1160      # TODO(akshayka): Remove special case once b/68017812 is fixed.
1161      if in_graph_mode:
1162        cell = rnn_cell.LSTMCell(
1163            num_units,
1164            use_peepholes=True,
1165            initializer=initializer,
1166            num_proj=num_proj,
1167            state_is_tuple=False)
1168
1169      with variable_scope.variable_scope("dynamic_scope"):
1170        outputs_static, state_static = rnn.static_rnn(
1171            cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)
1172
1173      if in_graph_mode:
1174        # Generate gradients of sum of outputs w.r.t. inputs
1175        static_gradients = gradients_impl.gradients(
1176            outputs_static + [state_static], [concat_inputs])
1177        # Generate gradients of individual outputs w.r.t. inputs
1178        static_individual_gradients = nest.flatten([
1179            gradients_impl.gradients(y, [concat_inputs])
1180            for y in [outputs_static[0], outputs_static[-1], state_static]
1181        ])
1182        # Generate gradients of individual outputs w.r.t. trainable variables
1183        trainable_variables = ops.get_collection(
1184            ops.GraphKeys.TRAINABLE_VARIABLES)
1185        assert len(trainable_variables) > 1, (
1186            "Count of trainable variables: %d" % len(trainable_variables))
1187        # pylint: disable=bad-builtin
1188        static_individual_variable_gradients = nest.flatten([
1189            gradients_impl.gradients(y, trainable_variables)
1190            for y in [outputs_static[0], outputs_static[-1], state_static]
1191        ])
1192        # Generate gradients and run sessions to obtain outputs
1193        feeds = {concat_inputs: input_values}
1194        # Initialize
1195        variables_lib.global_variables_initializer().run(feed_dict=feeds)
1196        # Test forward pass
1197        values_static = sess.run(outputs_static, feed_dict=feeds)
1198        (state_value_static,) = sess.run((state_static,), feed_dict=feeds)
1199
1200        # Test gradients of outputs & final state w.r.t. inputs and variables
1201        static_grad_values = sess.run(static_gradients, feed_dict=feeds)
1202
1203        static_individual_grad_values = sess.run(
1204            static_individual_gradients, feed_dict=feeds)
1205
1206        static_individual_var_grad_values = sess.run(
1207            static_individual_variable_gradients, feed_dict=feeds)
1208
1209    ########## Step 2: Run dynamic graph and generate readouts
1210    with self.session(graph=ops.Graph()) as sess:
1211      if in_graph_mode:
1212        concat_inputs = array_ops.placeholder(
1213            dtypes.float32, shape=(time_steps, batch_size, input_size))
1214      else:
1215        concat_inputs = constant_op.constant(input_values)
1216      initializer = init_ops.random_uniform_initializer(
1217          -0.01, 0.01, seed=self._seed)
1218
1219      # TODO(akshayka): Remove this special case once b/68017812 is
1220      # fixed.
1221      if in_graph_mode:
1222        cell = rnn_cell.LSTMCell(
1223            num_units,
1224            use_peepholes=True,
1225            initializer=initializer,
1226            num_proj=num_proj,
1227            state_is_tuple=False)
1228
1229      with variable_scope.variable_scope("dynamic_scope"):
1230        outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
1231            cell,
1232            inputs=concat_inputs,
1233            sequence_length=sequence_length,
1234            time_major=True,
1235            dtype=dtypes.float32)
1236        split_outputs_dynamic = array_ops.unstack(outputs_dynamic, time_steps)
1237
1238      if in_graph_mode:
1239
1240        # Generate gradients of sum of outputs w.r.t. inputs
1241        dynamic_gradients = gradients_impl.gradients(
1242            split_outputs_dynamic + [state_dynamic], [concat_inputs])
1243
1244        # Generate gradients of several individual outputs w.r.t. inputs
1245        dynamic_individual_gradients = nest.flatten([
1246            gradients_impl.gradients(y, [concat_inputs])
1247            for y in [
1248                split_outputs_dynamic[0], split_outputs_dynamic[-1],
1249                state_dynamic
1250            ]
1251        ])
1252
1253        # Generate gradients of individual outputs w.r.t. trainable variables
1254        trainable_variables = ops.get_collection(
1255            ops.GraphKeys.TRAINABLE_VARIABLES)
1256        assert len(trainable_variables) > 1, (
1257            "Count of trainable variables: %d" % len(trainable_variables))
1258        dynamic_individual_variable_gradients = nest.flatten([
1259            gradients_impl.gradients(y, trainable_variables)
1260            for y in [
1261                split_outputs_dynamic[0], split_outputs_dynamic[-1],
1262                state_dynamic
1263            ]
1264        ])
1265
1266        feeds = {concat_inputs: input_values}
1267
1268        # Initialize
1269        variables_lib.global_variables_initializer().run(feed_dict=feeds)
1270
1271        # Test forward pass
1272        values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds)
1273        (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds)
1274
1275        # Test gradients of outputs & final state w.r.t. inputs and variables
1276        dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds)
1277
1278        dynamic_individual_grad_values = sess.run(
1279            dynamic_individual_gradients, feed_dict=feeds)
1280
1281        dynamic_individual_var_grad_values = sess.run(
1282            dynamic_individual_variable_gradients, feed_dict=feeds)
1283
1284    ########## Step 3: Comparisons
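    # In eager mode the tensors above already hold concrete values, so compare
    # them directly rather than session-run results.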
1285    if not in_graph_mode:
1286      values_static = outputs_static
1287      values_dynamic = split_outputs_dynamic
1288      state_value_static = state_static
1289      state_value_dynamic = state_dynamic
1290
1291    self.assertEqual(len(values_static), len(values_dynamic))
1292    for (value_static, value_dynamic) in zip(values_static, values_dynamic):
1293      self.assertAllClose(value_static, value_dynamic)
1294    self.assertAllClose(state_value_static, state_value_dynamic)
1295
1296    if in_graph_mode:
1297
1298      self.assertAllClose(static_grad_values, dynamic_grad_values)
1299
1300      self.assertEqual(
1301          len(static_individual_grad_values),
1302          len(dynamic_individual_grad_values))
1303      self.assertEqual(
1304          len(static_individual_var_grad_values),
1305          len(dynamic_individual_var_grad_values))
1306
1307      for i, (a, b) in enumerate(
1308          zip(static_individual_grad_values, dynamic_individual_grad_values)):
1309        tf_logging.info("Comparing individual gradients iteration %d" % i)
1310        self.assertAllClose(a, b)
1311
1312      for i, (a, b) in enumerate(
1313          zip(static_individual_var_grad_values,
1314              dynamic_individual_var_grad_values)):
1315        tf_logging.info(
1316            "Comparing individual variable gradients iteration %d" % i)
1317        self.assertAllClose(a, b)
1318
1319  @test_util.run_in_graph_and_eager_modes
1320  def testDynamicEquivalentToStaticRNN(self):
1321    self._testDynamicEquivalentToStaticRNN(use_sequence_length=False)
1322
1323  @test_util.run_in_graph_and_eager_modes
1324  def testDynamicEquivalentToStaticRNNWithSequenceLength(self):
1325    self._testDynamicEquivalentToStaticRNN(use_sequence_length=True)
1326
1327  @test_util.run_in_graph_and_eager_modes
1328  def testLSTMBlockCellErrorHandling(self):
1329    forget_bias = 1
1330    cell_clip = 0
1331    use_peephole = False
1332    x = constant_op.constant(0.837607, shape=[28, 29], dtype=dtypes.float32)
1333    cs_prev = constant_op.constant(0, shape=[28, 17], dtype=dtypes.float32)
1334    h_prev = constant_op.constant(
1335        0.592631638, shape=[28, 17], dtype=dtypes.float32)
1336    w = constant_op.constant(0.887386262, shape=[46, 68], dtype=dtypes.float32)
1337    wci = constant_op.constant(0, shape=[], dtype=dtypes.float32)
1338    wcf = constant_op.constant(0, shape=[17], dtype=dtypes.float32)
1339    wco = constant_op.constant(
1340        0.592631638, shape=[28, 17], dtype=dtypes.float32)
1341    b = constant_op.constant(0.75259006, shape=[68], dtype=dtypes.float32)
1342    with self.assertRaises(errors_impl.InvalidArgumentError):
1343      self.evaluate(
1344          gen_rnn_ops.lstm_block_cell(
1345              x=x,
1346              cs_prev=cs_prev,
1347              h_prev=h_prev,
1348              w=w,
1349              wci=wci,
1350              wcf=wcf,
1351              wco=wco,
1352              b=b,
1353              forget_bias=forget_bias,
1354              cell_clip=cell_clip,
1355              use_peephole=use_peephole))
1356
1357  @test_util.run_in_graph_and_eager_modes
1358  def testLSTMBlockCellGradErrorHandling(self):
1359    use_peephole = False
1360    seq_len_max = constant_op.constant(1, shape=[], dtype=dtypes.int64)
1361    x = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
1362    cs_prev = constant_op.constant(
1363        0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
1364    h_prev = constant_op.constant(
1365        0.504355371, shape=[1, 1], dtype=dtypes.float32)
1366    w = constant_op.constant(0.504355371, shape=[1, 1], dtype=dtypes.float32)
1367    wci = constant_op.constant(0.504355371, shape=[1], dtype=dtypes.float32)
1368    wcf = constant_op.constant(0.504355371, shape=[1], dtype=dtypes.float32)
1369    wco = constant_op.constant(0.504355371, shape=[1], dtype=dtypes.float32)
1370    b = constant_op.constant(0.504355371, shape=[1], dtype=dtypes.float32)
1371    i = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
1372    cs = constant_op.constant(
1373        0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
1374    f = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
1375    o = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
1376    ci = constant_op.constant(
1377        0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
1378    co = constant_op.constant(
1379        0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
1380    h = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
1381    cs_grad = constant_op.constant(
1382        0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
1383    h_grad = constant_op.constant(
1384        0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
1385    with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
1386                                "must be rank"):
1387      self.evaluate(
1388          gen_rnn_ops.block_lstm_grad_v2(
1389              seq_len_max=seq_len_max,
1390              x=x,
1391              cs_prev=cs_prev,
1392              h_prev=h_prev,
1393              w=w,
1394              wci=wci,
1395              wcf=wcf,
1396              wco=wco,
1397              b=b,
1398              i=i,
1399              cs=cs,
1400              f=f,
1401              o=o,
1402              ci=ci,
1403              co=co,
1404              h=h,
1405              cs_grad=cs_grad,
1406              h_grad=h_grad,
1407              use_peephole=use_peephole))
1408
1409
1410class BidirectionalRNNTest(test.TestCase):
1411
1412  def setUp(self):
1413    self._seed = 23489
1414    np.random.seed(self._seed)
1415
1416  def _createBidirectionalRNN(self, use_shape, use_sequence_length, scope=None):
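    """Builds a static bidirectional LSTM RNN and returns its endpoints.

    The forward and backward LSTMCells share the same random-uniform
    initializer, so their outputs are expected to match (reversed in time).
    Returns (input_value, inputs, outputs, state_fw, state_bw,
    sequence_length).
    """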
1417    num_units = 3
1418    input_size = 5
1419    batch_size = 2
1420    max_length = 8
1421
1422    initializer = init_ops.random_uniform_initializer(
1423        -0.01, 0.01, seed=self._seed)
1424    sequence_length = array_ops.placeholder(
1425        dtypes.int64) if use_sequence_length else None
1426    cell_fw = rnn_cell.LSTMCell(
1427        num_units, input_size, initializer=initializer, state_is_tuple=False)
1428    cell_bw = rnn_cell.LSTMCell(
1429        num_units, input_size, initializer=initializer, state_is_tuple=False)
1430    inputs = max_length * [
1431        array_ops.placeholder(
1432            dtypes.float32,
1433            shape=(batch_size, input_size) if use_shape else (None, input_size))
1434    ]
1435    outputs, state_fw, state_bw = rnn.static_bidirectional_rnn(
1436        cell_fw,
1437        cell_bw,
1438        inputs,
1439        dtype=dtypes.float32,
1440        sequence_length=sequence_length,
1441        scope=scope)
1442    self.assertEqual(len(outputs), len(inputs))
1443    for out in outputs:
1444      self.assertEqual(out.get_shape().as_list(),
1445                       [batch_size if use_shape else None, 2 * num_units])
1446
1447    input_value = np.random.randn(batch_size, input_size)
1448    outputs = array_ops.stack(outputs)
1449
1450    return input_value, inputs, outputs, state_fw, state_bw, sequence_length
1451
1452  def _testBidirectionalRNN(self, use_shape):
1453    with self.session(graph=ops.Graph()) as sess:
1454      input_value, inputs, outputs, state_fw, state_bw, sequence_length = (
1455          self._createBidirectionalRNN(use_shape, True))
1456      variables_lib.global_variables_initializer().run()
1457      # Run with pre-specified sequence length of 2, 3
1458      out, s_fw, s_bw = sess.run(
1459          [outputs, state_fw, state_bw],
1460          feed_dict={
1461              inputs[0]: input_value,
1462              sequence_length: [2, 3]
1463          })
1464
1465      # Since the forward and backward LSTM cells were initialized with the
1466      # same parameters, the forward and backward output has to be the same,
1467      # but reversed in time. The format is output[time][batch][depth], and
1468      # due to depth concatenation (as num_units=3 for both RNNs):
1469      # - forward output:  out[][][depth] for 0 <= depth < 3
1470      # - backward output: out[][][depth] for 3 <= depth < 6
1471      #
1472      # First sequence in batch is length=2
1473      # Check that the time=0 forward output is equal to time=1 backward output
1474      self.assertAllClose(out[0][0][0], out[1][0][3])
1475      self.assertAllClose(out[0][0][1], out[1][0][4])
1476      self.assertAllClose(out[0][0][2], out[1][0][5])
1477      # Check that the time=1 forward output is equal to time=0 backward output
1478      self.assertAllClose(out[1][0][0], out[0][0][3])
1479      self.assertAllClose(out[1][0][1], out[0][0][4])
1480      self.assertAllClose(out[1][0][2], out[0][0][5])
1481
1482      # Second sequence in batch is length=3
1483      # Check that the time=0 forward output is equal to time=2 backward output
1484      self.assertAllClose(out[0][1][0], out[2][1][3])
1485      self.assertAllClose(out[0][1][1], out[2][1][4])
1486      self.assertAllClose(out[0][1][2], out[2][1][5])
1487      # Check that the time=1 forward output is equal to time=1 backward output
1488      self.assertAllClose(out[1][1][0], out[1][1][3])
1489      self.assertAllClose(out[1][1][1], out[1][1][4])
1490      self.assertAllClose(out[1][1][2], out[1][1][5])
1491      # Check that the time=2 forward output is equal to time=0 backward output
1492      self.assertAllClose(out[2][1][0], out[0][1][3])
1493      self.assertAllClose(out[2][1][1], out[0][1][4])
1494      self.assertAllClose(out[2][1][2], out[0][1][5])
1495      # Via the reasoning above, the forward and backward final state should be
1496      # exactly the same
1497      self.assertAllClose(s_fw, s_bw)
1498
1499  def _testBidirectionalRNNWithoutSequenceLength(self, use_shape):
1500    with self.session(graph=ops.Graph()) as sess:
1501      input_value, inputs, outputs, state_fw, state_bw, _ = (
1502          self._createBidirectionalRNN(use_shape, False))
1503      variables_lib.global_variables_initializer().run()
1504      out, s_fw, s_bw = sess.run(
1505          [outputs, state_fw, state_bw], feed_dict={
1506              inputs[0]: input_value
1507          })
1508
1509      # Since the forward and backward LSTM cells were initialized with the
1510      # same parameters, the forward and backward output has to be the same,
1511      # but reversed in time. The format is output[time][batch][depth], and
1512      # due to depth concatenation (as num_units=3 for both RNNs):
1513      # - forward output:  out[][][depth] for 0 <= depth < 3
1514      # - backward output: out[][][depth] for 3 <= depth < 6
1515      #
1516      # Both sequences in batch are length=8.  Check that the time=i
1517      # forward output is equal to time=8-1-i backward output
1518      for i in range(8):
1519        self.assertAllClose(out[i][0][0:3], out[8 - 1 - i][0][3:6])
1520        self.assertAllClose(out[i][1][0:3], out[8 - 1 - i][1][3:6])
1521      # Via the reasoning above, the forward and backward final state should be
1522      # exactly the same
1523      self.assertAllClose(s_fw, s_bw)
1524
1525  @test_util.run_v1_only("b/124229375")
1526  def testBidirectionalRNN(self):
1527    self._testBidirectionalRNN(use_shape=False)
1528    self._testBidirectionalRNN(use_shape=True)
1529
1530  @test_util.run_v1_only("b/124229375")
1531  def testBidirectionalRNNWithoutSequenceLength(self):
1532    self._testBidirectionalRNNWithoutSequenceLength(use_shape=False)
1533    self._testBidirectionalRNNWithoutSequenceLength(use_shape=True)
1534
1535  def _createBidirectionalDynamicRNN(self,
1536                                     use_shape,
1537                                     use_state_tuple,
1538                                     use_time_major,
1539                                     use_sequence_length,
1540                                     scope=None):
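    """Builds a bidirectional dynamic LSTM RNN and returns its endpoints.

    Both directions use LSTMCells with the same initializer, and the per-step
    outputs are concatenated along the depth axis (2 * num_units channels).
    Returns (input_value, inputs, outputs, state_fw, state_bw,
    sequence_length).
    """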
1541    num_units = 3
1542    input_size = 5
1543    batch_size = 2
1544    max_length = 8
1545
1546    initializer = init_ops.random_uniform_initializer(
1547        -0.01, 0.01, seed=self._seed)
1548    sequence_length = (
1549        array_ops.placeholder(dtypes.int64) if use_sequence_length else None)
1550    cell_fw = rnn_cell.LSTMCell(
1551        num_units, initializer=initializer, state_is_tuple=use_state_tuple)
1552    cell_bw = rnn_cell.LSTMCell(
1553        num_units, initializer=initializer, state_is_tuple=use_state_tuple)
1554    inputs = max_length * [
1555        array_ops.placeholder(
1556            dtypes.float32,
1557            shape=(batch_size if use_shape else None, input_size))
1558    ]
1559    inputs_c = array_ops.stack(inputs)
1560    if not use_time_major:
1561      inputs_c = array_ops.transpose(inputs_c, [1, 0, 2])
1562    outputs, states = rnn.bidirectional_dynamic_rnn(
1563        cell_fw,
1564        cell_bw,
1565        inputs_c,
1566        sequence_length,
1567        dtype=dtypes.float32,
1568        time_major=use_time_major,
1569        scope=scope)
1570    outputs = array_ops.concat(outputs, 2)
1571    state_fw, state_bw = states
1572    outputs_shape = [None, max_length, 2 * num_units]
1573    if use_shape:
1574      outputs_shape[0] = batch_size
1575    if use_time_major:
1576      outputs_shape[0], outputs_shape[1] = outputs_shape[1], outputs_shape[0]
1577    self.assertEqual(outputs.get_shape().as_list(), outputs_shape)
1578
1579    input_value = np.random.randn(batch_size, input_size)
1580
1581    return input_value, inputs, outputs, state_fw, state_bw, sequence_length
1582
1583  def _testBidirectionalDynamicRNN(self, use_shape, use_state_tuple,
1584                                   use_time_major, use_sequence_length):
1585    with self.session(graph=ops.Graph()) as sess:
1586      input_value, inputs, outputs, state_fw, state_bw, sequence_length = (
1587          self._createBidirectionalDynamicRNN(
1588              use_shape, use_state_tuple, use_time_major, use_sequence_length))
1589      variables_lib.global_variables_initializer().run()
1590      # Run with pre-specified sequence length of 2, 3
1591      feed_dict = ({sequence_length: [2, 3]} if use_sequence_length else {})
1592      feed_dict.update({inputs[0]: input_value})
1593      if use_state_tuple:
1594        out, c_fw, m_fw, c_bw, m_bw = sess.run(
1595            [outputs, state_fw[0], state_fw[1], state_bw[0], state_bw[1]],
1596            feed_dict=feed_dict)
1597        s_fw = (c_fw, m_fw)
1598        s_bw = (c_bw, m_bw)
1599      else:
1600        feed_dict.update({inputs[0]: input_value})
1601        out, s_fw, s_bw = sess.run(
1602            [outputs, state_fw, state_bw], feed_dict=feed_dict)
1603
1604      # Since the forward and backward LSTM cells were initialized with the
1605      # same parameters, the forward and backward output has to be the same,
1606      # but reversed in time. The format is output[time][batch][depth], and
1607      # due to depth concatenation (as num_units=3 for both RNNs):
1608      # - forward output:  out[][][depth] for 0 <= depth < 3
1609      # - backward output: out[][][depth] for 3 <= depth < 6
1610      #
1611      if not use_time_major:
1612        out = np.swapaxes(out, 0, 1)
1613
1614      if use_sequence_length:
1615        # First sequence in batch is length=2
1616        # Check that the t=0 forward output is equal to t=1 backward output
1617        self.assertEqual(out[0][0][0], out[1][0][3])
1618        self.assertEqual(out[0][0][1], out[1][0][4])
1619        self.assertEqual(out[0][0][2], out[1][0][5])
1620        # Check that the t=1 forward output is equal to t=0 backward output
1621        self.assertEqual(out[1][0][0], out[0][0][3])
1622        self.assertEqual(out[1][0][1], out[0][0][4])
1623        self.assertEqual(out[1][0][2], out[0][0][5])
1624
1625        # Second sequence in batch is length=3
1626        # Check that the t=0 forward output is equal to t=2 backward output
1627        self.assertEqual(out[0][1][0], out[2][1][3])
1628        self.assertEqual(out[0][1][1], out[2][1][4])
1629        self.assertEqual(out[0][1][2], out[2][1][5])
1630        # Check that the t=1 forward output is equal to t=1 backward output
1631        self.assertEqual(out[1][1][0], out[1][1][3])
1632        self.assertEqual(out[1][1][1], out[1][1][4])
1633        self.assertEqual(out[1][1][2], out[1][1][5])
1634        # Check that the t=2 forward output is equal to t=0 backward output
1635        self.assertEqual(out[2][1][0], out[0][1][3])
1636        self.assertEqual(out[2][1][1], out[0][1][4])
1637        self.assertEqual(out[2][1][2], out[0][1][5])
1638        # Via the reasoning above, the forward and backward final state should
1639        # be exactly the same
1640        self.assertAllClose(s_fw, s_bw)
1641      else:  # not use_sequence_length
1642        max_length = 8  # from _createBidirectionalDynamicRNN
1643        for t in range(max_length):
1644          self.assertAllEqual(out[t, :, 0:3], out[max_length - t - 1, :, 3:6])
1645        self.assertAllClose(s_fw, s_bw)
1646
1647  @test_util.run_v1_only("b/124229375")
1648  def testBidirectionalDynamicRNN(self):
1649    # Generate 2^5 option values
1650    # from [True, True, True, True, True] to [False, False, False, False, False]
1651    options = itertools.product([True, False], repeat=4)
1652    for option in options:
1653      self._testBidirectionalDynamicRNN(
1654          use_shape=option[0],
1655          use_state_tuple=option[1],
1656          use_time_major=option[2],
1657          use_sequence_length=option[3])
1658
1659  def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
1660    # REMARKS: factory(scope) is a function accepting a scope
1661    #          as an argument; the scope can be None, a string,
1662    #          or a VariableScope instance.
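    #          The factory is run under that scope and the test then checks
    #          that every global variable it created lives under the expected
    #          scope prefix.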
1663    with self.session(graph=ops.Graph()):
1664      if use_outer_scope:
1665        with variable_scope.variable_scope(prefix) as scope:
1666          factory(scope)
1667      else:
1668        factory(prefix)
1669
1670      # Check that all the variable names start
1671      # with the proper scope.
1672      variables_lib.global_variables_initializer()
1673      all_vars = variables_lib.global_variables()
1674      prefix = prefix or "bidirectional_rnn"
1675      scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")]
1676      tf_logging.info("BiRNN with scope: %s (%s)" %
1677                      (prefix, "scope" if use_outer_scope else "str"))
1678      for v in scope_vars:
1679        tf_logging.info(v.name)
1680      self.assertEqual(len(scope_vars), len(all_vars))
1681
1682  @test_util.run_v1_only("b/124229375")
1683  def testBidirectionalRNNScope(self):
1684
1685    def factory(scope):
1686      return self._createBidirectionalRNN(
1687          use_shape=True, use_sequence_length=True, scope=scope)
1688
1689    self._testScope(factory, use_outer_scope=True)
1690    self._testScope(factory, use_outer_scope=False)
1691    self._testScope(factory, prefix=None, use_outer_scope=False)
1692
1693  @test_util.run_v1_only("b/124229375")
1694  def testBidirectionalDynamicRNNScope(self):
1695
1696    def get_factory(use_time_major):
1697
1698      def factory(scope):
1699        return self._createBidirectionalDynamicRNN(
1700            use_shape=True,
1701            use_state_tuple=True,
1702            use_sequence_length=True,
1703            use_time_major=use_time_major,
1704            scope=scope)
1705
1706      return factory
1707
1708    self._testScope(get_factory(True), use_outer_scope=True)
1709    self._testScope(get_factory(True), use_outer_scope=False)
1710    self._testScope(get_factory(True), prefix=None, use_outer_scope=False)
1711    self._testScope(get_factory(False), use_outer_scope=True)
1712    self._testScope(get_factory(False), use_outer_scope=False)
1713    self._testScope(get_factory(False), prefix=None, use_outer_scope=False)
1714
1715
1716class MultiDimensionalLSTMTest(test.TestCase):
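  """Tests RNN containers with multi-dimensional (non-vector) inputs."""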
1717
1718  def setUp(self):
1719    self._seed = 23489
1720    np.random.seed(self._seed)
1721
1722  @test_util.run_v1_only("b/124229375")
1723  def testMultiDimensionalLSTMAllRNNContainers(self):
1724    feature_dims = (3, 4, 5)
1725    input_size = feature_dims
1726    batch_size = 2
1727    max_length = 8
1728    sequence_length = [4, 6]
1729    with self.session(graph=ops.Graph()) as sess:
1730      inputs = max_length * [
1731          array_ops.placeholder(dtypes.float32, shape=(None,) + input_size)
1732      ]
1733      inputs_using_dim = max_length * [
1734          array_ops.placeholder(
1735              dtypes.float32, shape=(batch_size,) + input_size)
1736      ]
1737      inputs_c = array_ops.stack(inputs)
1738      # Create a cell for the whole test. This is fine because the cell has no
1739      # variables.
1740      cell = DummyMultiDimensionalLSTM(feature_dims)
1741      state_saver = TestStateSaver(batch_size, input_size)
1742      outputs_static, state_static = rnn.static_rnn(
1743          cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length)
1744      outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
1745          cell,
1746          inputs_c,
1747          dtype=dtypes.float32,
1748          time_major=True,
1749          sequence_length=sequence_length)
1750      outputs_bid, state_fw, state_bw = rnn.static_bidirectional_rnn(
1751          cell,
1752          cell,
1753          inputs_using_dim,
1754          dtype=dtypes.float32,
1755          sequence_length=sequence_length)
1756      outputs_sav, state_sav = rnn.static_state_saving_rnn(
1757          cell,
1758          inputs_using_dim,
1759          sequence_length=sequence_length,
1760          state_saver=state_saver,
1761          state_name=("h", "c"))
1762
1763      self.assertEqual(outputs_dynamic.get_shape().as_list(),
1764                       inputs_c.get_shape().as_list())
1765      for out, inp in zip(outputs_static, inputs):
1766        self.assertEqual(out.get_shape().as_list(), inp.get_shape().as_list())
1767      for out, inp in zip(outputs_bid, inputs_using_dim):
1768        input_shape_list = inp.get_shape().as_list()
1769        # fwd and bwd activations are concatenated along the second dim.
1770        input_shape_list[1] *= 2
1771        self.assertEqual(out.get_shape().as_list(), input_shape_list)
1772
1773      variables_lib.global_variables_initializer().run()
1774
1775      input_total_size = (batch_size,) + input_size
1776      input_value = np.random.randn(*input_total_size)
1777      outputs_static_v = sess.run(
1778          outputs_static, feed_dict={
1779              inputs[0]: input_value
1780          })
1781      outputs_dynamic_v = sess.run(
1782          outputs_dynamic, feed_dict={
1783              inputs[0]: input_value
1784          })
1785      outputs_bid_v = sess.run(
1786          outputs_bid, feed_dict={
1787              inputs_using_dim[0]: input_value
1788          })
1789      outputs_sav_v = sess.run(
1790          outputs_sav, feed_dict={
1791              inputs_using_dim[0]: input_value
1792          })
1793
1794      self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
1795      self.assertAllEqual(outputs_static_v, outputs_sav_v)
1796      outputs_static_array = np.array(outputs_static_v)
1797      outputs_static_array_double = np.concatenate(
1798          (outputs_static_array, outputs_static_array), axis=2)
1799      outputs_bid_array = np.array(outputs_bid_v)
1800      self.assertAllEqual(outputs_static_array_double, outputs_bid_array)
1801
1802      state_static_v = sess.run(
1803          state_static, feed_dict={
1804              inputs[0]: input_value
1805          })
1806      state_dynamic_v = sess.run(
1807          state_dynamic, feed_dict={
1808              inputs[0]: input_value
1809          })
1810      state_bid_fw_v = sess.run(
1811          state_fw, feed_dict={
1812              inputs_using_dim[0]: input_value
1813          })
1814      state_bid_bw_v = sess.run(
1815          state_bw, feed_dict={
1816              inputs_using_dim[0]: input_value
1817          })
1818      state_sav_v = sess.run(
1819          state_sav, feed_dict={
1820              inputs_using_dim[0]: input_value
1821          })
1822      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
1823      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v))
1824      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v))
1825      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_bw_v))
1826
1827
1828class NestedLSTMTest(test.TestCase):
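  """Tests RNN containers whose inputs and outputs are nested tuples."""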
1829
1830  def setUp(self):
1831    self._seed = 23489
1832    np.random.seed(self._seed)
1833
1834  @test_util.run_v1_only("b/124229375")
1835  def testNestedIOLSTMAllRNNContainers(self):
1836    input_size = 5
1837    batch_size = 2
1838    state_size = 6
1839    max_length = 8
1840    sequence_length = [4, 6]
1841    with self.session(graph=ops.Graph()) as sess:
1842      state_saver = TestStateSaver(batch_size, state_size)
1843      single_input = (array_ops.placeholder(
1844          dtypes.float32, shape=(None, input_size)),
1845                      array_ops.placeholder(
1846                          dtypes.float32, shape=(None, input_size)))
1847      inputs = max_length * [single_input]
1848      inputs_c = (array_ops.stack([input_[0] for input_ in inputs]),
1849                  array_ops.stack([input_[1] for input_ in inputs]))
1850      single_input_using_dim = (array_ops.placeholder(
1851          dtypes.float32, shape=(batch_size, input_size)),
1852                                array_ops.placeholder(
1853                                    dtypes.float32,
1854                                    shape=(batch_size, input_size)))
1855      inputs_using_dim = max_length * [single_input_using_dim]
1856
1857      # Create a cell for the whole test. This is fine because the cell has no
1858      # variables.
1859      cell = NestedRNNCell()
1860      outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
1861          cell,
1862          inputs_c,
1863          dtype=dtypes.float32,
1864          time_major=True,
1865          sequence_length=sequence_length)
1866      outputs_static, state_static = rnn.static_rnn(
1867          cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length)
1868      outputs_bid, state_fw, state_bw = rnn.static_bidirectional_rnn(
1869          cell,
1870          cell,
1871          inputs_using_dim,
1872          dtype=dtypes.float32,
1873          sequence_length=sequence_length)
1874      outputs_sav, state_sav = rnn.static_state_saving_rnn(
1875          cell,
1876          inputs_using_dim,
1877          sequence_length=sequence_length,
1878          state_saver=state_saver,
1879          state_name=("h", "c"))
1880
1881      def _assert_same_shape(input1, input2, double=False):
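        """Asserts that each flattened element of input2 matches input1's shape.

        When `double` is True the second dimension is doubled first, to account
        for the depth concatenation of bidirectional outputs.
        """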
1882        flat_input1 = nest.flatten(input1)
1883        flat_input2 = nest.flatten(input2)
1884        for inp1, inp2 in zip(flat_input1, flat_input2):
1885          input_shape = inp1.get_shape().as_list()
1886          if double:
1887            input_shape[1] *= 2
1888          self.assertEqual(input_shape, inp2.get_shape().as_list())
1889
1890      _assert_same_shape(inputs_c, outputs_dynamic)
1891      _assert_same_shape(inputs, outputs_static)
1892      _assert_same_shape(inputs_using_dim, outputs_sav)
1893      _assert_same_shape(inputs_using_dim, outputs_bid, double=True)
1894
1895      variables_lib.global_variables_initializer().run()
1896
1897      input_total_size = (batch_size, input_size)
1898      input_value = (np.random.randn(*input_total_size),
1899                     np.random.randn(*input_total_size))
1900      outputs_dynamic_v = sess.run(
1901          outputs_dynamic, feed_dict={
1902              single_input: input_value
1903          })
1904      outputs_static_v = sess.run(
1905          outputs_static, feed_dict={
1906              single_input: input_value
1907          })
1908      outputs_sav_v = sess.run(
1909          outputs_sav, feed_dict={
1910              single_input_using_dim: input_value
1911          })
1912      outputs_bid_v = sess.run(
1913          outputs_bid, feed_dict={
1914              single_input_using_dim: input_value
1915          })
1916
1917      self.assertAllEqual(outputs_static_v,
1918                          np.transpose(outputs_dynamic_v, (1, 0, 2, 3)))
1919      self.assertAllEqual(outputs_static_v, outputs_sav_v)
1920      outputs_static_array = np.array(outputs_static_v)
1921      outputs_static_array_double = np.concatenate(
1922          (outputs_static_array, outputs_static_array), axis=3)
1923      outputs_bid_array = np.array(outputs_bid_v)
1924      self.assertAllEqual(outputs_static_array_double, outputs_bid_array)
1925
1926      state_dynamic_v = sess.run(
1927          state_dynamic, feed_dict={
1928              single_input: input_value
1929          })
1930      state_static_v = sess.run(
1931          state_static, feed_dict={
1932              single_input: input_value
1933          })
1934      state_bid_fw_v = sess.run(
1935          state_fw, feed_dict={
1936              single_input_using_dim: input_value
1937          })
1938      state_bid_bw_v = sess.run(
1939          state_bw, feed_dict={
1940              single_input_using_dim: input_value
1941          })
1942      state_sav_v = sess.run(
1943          state_sav, feed_dict={
1944              single_input_using_dim: input_value
1945          })
1946      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
1947      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v))
1948      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v))
1949      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_bw_v))
1950
1951
1952class StateSaverRNNTest(test.TestCase):
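  """Tests for static_state_saving_rnn scoping and state-saver bookkeeping."""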
1953
1954  def setUp(self):
1955    self._seed = 23489
1956    np.random.seed(self._seed)
1957
1958  def _factory(self, scope, state_saver):
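    """Builds a static state-saving LSTM RNN under `scope`.

    Returns the RNN outputs, the final state, and the state_saver that was
    passed in.
    """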
1959    num_units = state_saver.state_size // 2
1960    batch_size = state_saver.batch_size
1961    input_size = 5
1962    max_length = 8
1963    initializer = init_ops.random_uniform_initializer(
1964        -0.01, 0.01, seed=self._seed)
1965    cell = rnn_cell.LSTMCell(
1966        num_units,
1967        use_peepholes=False,
1968        initializer=initializer,
1969        state_is_tuple=False)
1970    inputs = max_length * [
1971        array_ops.zeros(dtype=dtypes.float32, shape=(batch_size, input_size))
1972    ]
1973    out, state = rnn.static_state_saving_rnn(
1974        cell,
1975        inputs,
1976        state_saver=state_saver,
1977        state_name="save_lstm",
1978        scope=scope)
1979    return out, state, state_saver
1980
1981  def _testScope(self, prefix="prefix", use_outer_scope=True):
1982    num_units = 3
1983    batch_size = 2
1984    state_saver = TestStateSaver(batch_size, 2 * num_units)
1985
1986    with self.session(graph=ops.Graph()):
1987      if use_outer_scope:
1988        with variable_scope.variable_scope(prefix) as scope:
1989          self._factory(scope=scope, state_saver=state_saver)
1990      else:
1991        self._factory(scope=prefix, state_saver=state_saver)
1992        variables_lib.global_variables_initializer()
1993
1994      # Check that all the variable names start
1995      # with the proper scope.
1996      all_vars = variables_lib.global_variables()
1997      prefix = prefix or "rnn"
1998      scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")]
1999      tf_logging.info("RNN with scope: %s (%s)" %
2000                      (prefix, "scope" if use_outer_scope else "str"))
2001      for v in scope_vars:
2002        tf_logging.info(v.name)
2003      self.assertEqual(len(scope_vars), len(all_vars))
2004
2005  def testStateSaverRNNScope(self):
2006    self._testScope(use_outer_scope=True)
2007    self._testScope(use_outer_scope=False)
2008    self._testScope(prefix=None, use_outer_scope=False)
2009
2010  def testStateSaverCallsSaveState(self):
2011    """Tests that state and save_state are called the same number of times.
2012
2013    Verifies that whether the out and state tensors returned by
2014    static_state_saving_rnn are actually evaluated or skipped has no
2015    influence on the number of calls to the state_saver object's state and
2016    save_state methods (the two counts should always match).
2017    """
2018    self.skipTest("b/124196246 Breakage for sess.run([out, ...]): 2 != 1")
2019
2020    num_units = 3
2021    batch_size = 2
2022    state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units)
2023    out, state, state_saver = self._factory(scope=None, state_saver=state_saver)
2024
2025    with self.cached_session() as sess:
2026      sess.run(variables_lib.global_variables_initializer())
2027      sess.run(variables_lib.local_variables_initializer())
2028
2029      _, _, num_state_calls, num_save_state_calls = sess.run([
2030          out,
2031          state,
2032          state_saver.num_state_calls,
2033          state_saver.num_save_state_calls])
2034      self.assertEqual(num_state_calls, num_save_state_calls)
2035
2036      _, num_state_calls, num_save_state_calls = sess.run([
2037          out,
2038          state_saver.num_state_calls,
2039          state_saver.num_save_state_calls])
2040      self.assertEqual(num_state_calls, num_save_state_calls)
2041
2042      _, num_state_calls, num_save_state_calls = sess.run([
2043          state,
2044          state_saver.num_state_calls,
2045          state_saver.num_save_state_calls])
2046      self.assertEqual(num_state_calls, num_save_state_calls)
2047

2048class GRUTest(test.TestCase):
2049
2050  def setUp(self):
2051    self._seed = 23489
2052    np.random.seed(self._seed)
2053
2054  @test_util.run_v1_only("b/124229375")
2055  def testDynamic(self):
2056    time_steps = 8
2057    num_units = 3
2058    input_size = 5
2059    batch_size = 2
2060
2061    input_values = np.random.randn(time_steps, batch_size, input_size)
2062
2063    sequence_length = np.random.randint(0, time_steps, size=batch_size)
2064
2065    with self.session(graph=ops.Graph()) as sess:
2066      concat_inputs = array_ops.placeholder(
2067          dtypes.float32, shape=(time_steps, batch_size, input_size))
2068
2069      cell = rnn_cell.GRUCell(num_units=num_units)
2070
2071      with variable_scope.variable_scope("dynamic_scope"):
2072        outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
2073            cell,
2074            inputs=concat_inputs,
2075            sequence_length=sequence_length,
2076            time_major=True,
2077            dtype=dtypes.float32)
2078
2079      feeds = {concat_inputs: input_values}
2080
2081      # Initialize
2082      variables_lib.global_variables_initializer().run(feed_dict=feeds)
2083
2084      sess.run([outputs_dynamic, state_dynamic], feed_dict=feeds)
2085
2086  def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
2087    with self.session(graph=ops.Graph()):
2088      if use_outer_scope:
2089        with variable_scope.variable_scope(prefix) as scope:
2090          factory(scope)
2091      else:
2092        factory(prefix)
2093        variables_lib.global_variables_initializer()
2094
2095      # Check that all the variable names start
2096      # with the proper scope.
2097      all_vars = variables_lib.global_variables()
2098      prefix = prefix or "rnn"
2099      scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")]
2100      tf_logging.info("RNN with scope: %s (%s)" %
2101                      (prefix, "scope" if use_outer_scope else "str"))
2102      for v in scope_vars:
2103        tf_logging.info(v.name)
2104      self.assertEqual(len(scope_vars), len(all_vars))
2105
2106  @test_util.run_v1_only("b/124229375")
2107  def testDynamicScope(self):
2108    time_steps = 8
2109    num_units = 3
2110    input_size = 5
2111    batch_size = 2
2112    sequence_length = np.random.randint(0, time_steps, size=batch_size)
2113
2114    def factory(scope):
2115      concat_inputs = array_ops.placeholder(
2116          dtypes.float32, shape=(time_steps, batch_size, input_size))
2117      cell = rnn_cell.GRUCell(num_units=num_units)
2118      return rnn.dynamic_rnn(
2119          cell,
2120          inputs=concat_inputs,
2121          sequence_length=sequence_length,
2122          time_major=True,
2123          dtype=dtypes.float32,
2124          scope=scope)
2125
2126    self._testScope(factory, use_outer_scope=True)
2127    self._testScope(factory, use_outer_scope=False)
2128    self._testScope(factory, prefix=None, use_outer_scope=False)
2129
2130
2131class RawRNNTest(test.TestCase):
2132
2133  def setUp(self):
2134    self._seed = 23489
2135    np.random.seed(self._seed)
2136
2137  @test_util.run_v1_only("b/124229375")
2138  def _testRawRNN(self, max_time):
2139    with self.session(graph=ops.Graph()) as sess:
2140      batch_size = 16
2141      input_depth = 4
2142      num_units = 3
2143
2144      inputs = array_ops.placeholder(
2145          shape=(max_time, batch_size, input_depth), dtype=dtypes.float32)
2146      sequence_length = array_ops.placeholder(
2147          shape=(batch_size,), dtype=dtypes.int32)
2148      inputs_ta = tensor_array_ops.TensorArray(
2149          dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
2150      inputs_ta = inputs_ta.unstack(inputs)
2151
2152      cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
2153
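      # raw_rnn invokes loop_fn once before the first time step (with
      # cell_output=None) and then once per step; it must return
      # (elements_finished, next_input, next_cell_state, emit_output,
      # next_loop_state).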
2154      def loop_fn(time_, cell_output, cell_state, unused_loop_state):
2155        emit_output = cell_output  # == None for time == 0
2156        if cell_output is None:  # time == 0
2157          next_state = cell.zero_state(batch_size, dtypes.float32)
2158        else:
2159          next_state = cell_state  # copy state through
2160        elements_finished = (time_ >= sequence_length)
2161        finished = math_ops.reduce_all(elements_finished)
2162        # For the very final iteration, we must emit a dummy input
2163        next_input = control_flow_ops.cond(
2164            finished,
2165            lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32),
2166            lambda: inputs_ta.read(time_))
2167        return (elements_finished, next_input, next_state, emit_output, None)
2168
2169      reuse_scope = variable_scope.get_variable_scope()
2170
2171      outputs_ta, final_state, _ = rnn.raw_rnn(cell, loop_fn, scope=reuse_scope)
2172      outputs = outputs_ta.stack()
2173
2174      reuse_scope.reuse_variables()
2175      outputs_dynamic_rnn, final_state_dynamic_rnn = rnn.dynamic_rnn(
2176          cell,
2177          inputs,
2178          time_major=True,
2179          dtype=dtypes.float32,
2180          sequence_length=sequence_length,
2181          scope=reuse_scope)
2182
2183      variables = variables_lib.trainable_variables()
2184      gradients = gradients_impl.gradients([outputs, final_state],
2185                                           [inputs] + variables)
2186      gradients_dynamic_rnn = gradients_impl.gradients(
2187          [outputs_dynamic_rnn, final_state_dynamic_rnn], [inputs] + variables)
2188
2189      variables_lib.global_variables_initializer().run()
2190
2191      rand_input = np.random.randn(max_time, batch_size, input_depth)
2192      if max_time == 0:
2193        rand_seq_len = np.zeros(batch_size)
2194      else:
2195        rand_seq_len = np.random.randint(max_time, size=batch_size)
2196
2197      # To ensure same output lengths for dynamic_rnn and raw_rnn
2198      rand_seq_len[0] = max_time
2199
2200      (outputs_val, outputs_dynamic_rnn_val, final_state_val,
2201       final_state_dynamic_rnn_val) = sess.run(
2202           [outputs, outputs_dynamic_rnn, final_state, final_state_dynamic_rnn],
2203           feed_dict={
2204               inputs: rand_input,
2205               sequence_length: rand_seq_len
2206           })
2207
2208      self.assertAllClose(outputs_dynamic_rnn_val, outputs_val)
2209      self.assertAllClose(final_state_dynamic_rnn_val, final_state_val)
2210
2211      # NOTE: With 0 time steps, raw_rnn has no shape information about
2212      # the input, so gradient comparisons are impossible (the gradient
2213      # evaluation would fail).  This case therefore skips the gradients
2214      # test.
2215      if max_time > 0:
2216        self.assertEqual(len(gradients), len(gradients_dynamic_rnn))
2217        gradients_val = sess.run(
2218            gradients,
2219            feed_dict={
2220                inputs: rand_input,
2221                sequence_length: rand_seq_len
2222            })
2223        gradients_dynamic_rnn_val = sess.run(
2224            gradients_dynamic_rnn,
2225            feed_dict={
2226                inputs: rand_input,
2227                sequence_length: rand_seq_len
2228            })
2229        self.assertEqual(len(gradients_val), len(gradients_dynamic_rnn_val))
2230        input_gradients_val = gradients_val[0]
2231        input_gradients_dynamic_rnn_val = gradients_dynamic_rnn_val[0]
2232        self.assertAllClose(input_gradients_val,
2233                            input_gradients_dynamic_rnn_val)
2234        for i in range(1, len(gradients_val)):
2235          self.assertAllClose(gradients_dynamic_rnn_val[i], gradients_val[i])
2236
2237  @test_util.run_v1_only("b/124229375")
2238  def testRawRNNZeroLength(self):
2239    # NOTE: With 0 time steps, raw_rnn has no shape information about the
2240    # input, so gradient comparisons are impossible (the gradient
2241    # evaluation would fail).  This case therefore skips the
2242    # gradients test.
2243    self._testRawRNN(max_time=0)
2244
2245  def testRawRNN(self):
2246    self._testRawRNN(max_time=10)
2247
2248  @test_util.run_v1_only("b/124229375")
2249  def testLoopState(self):
2250    with self.session(graph=ops.Graph()):
2251      max_time = 10
2252      batch_size = 16
2253      input_depth = 4
2254      num_units = 3
2255
2256      inputs = np.random.randn(max_time, batch_size, input_depth)
2257      inputs_ta = tensor_array_ops.TensorArray(
2258          dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
2259      inputs_ta = inputs_ta.unstack(inputs)
2260
2261      cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
2262
2263      def loop_fn(time_, cell_output, cell_state, loop_state):
2264        if cell_output is None:
2265          loop_state = constant_op.constant([0])
2266          next_state = cell.zero_state(batch_size, dtypes.float32)
2267        else:
2268          loop_state = array_ops.stack([array_ops.squeeze(loop_state) + 1])
2269          next_state = cell_state
2270        emit_output = cell_output  # == None for time == 0
2271        elements_finished = array_ops.tile([time_ >= max_time], [batch_size])
2272        finished = math_ops.reduce_all(elements_finished)
2273        # For the very final iteration, we must emit a dummy input
2274        next_input = control_flow_ops.cond(
2275            finished,
2276            lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32),
2277            lambda: inputs_ta.read(time_))
2278        return (elements_finished, next_input, next_state, emit_output,
2279                loop_state)
2280
2281      r = rnn.raw_rnn(cell, loop_fn)
2282      loop_state = r[-1]
2283      self.assertEqual([10], self.evaluate(loop_state))
2284
2285  @test_util.run_v1_only("b/124229375")
2286  def testLoopStateWithTensorArray(self):
2287    with self.session(graph=ops.Graph()):
2288      max_time = 4
2289      batch_size = 16
2290      input_depth = 4
2291      num_units = 3
2292
2293      inputs = np.random.randn(max_time, batch_size, input_depth)
2294      inputs_ta = tensor_array_ops.TensorArray(
2295          dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
2296      inputs_ta = inputs_ta.unstack(inputs)
2297
2298      cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
2299
2300      def loop_fn(time_, cell_output, cell_state, loop_state):
2301        if cell_output is None:
2302          loop_state = tensor_array_ops.TensorArray(
2303              dynamic_size=True,
2304              size=0,
2305              dtype=dtypes.int32,
2306              clear_after_read=False)
2307          loop_state = loop_state.write(0, 1)
2308          next_state = cell.zero_state(batch_size, dtypes.float32)
2309        else:
2310          loop_state = loop_state.write(time_,
2311                                        loop_state.read(time_ - 1) + time_)
2312          next_state = cell_state
2313        emit_output = cell_output  # == None for time == 0
2314        elements_finished = array_ops.tile([time_ >= max_time], [batch_size])
2315        finished = math_ops.reduce_all(elements_finished)
2316        # For the very final iteration, we must emit a dummy input
2317        next_input = control_flow_ops.cond(
2318            finished,
2319            lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32),
2320            lambda: inputs_ta.read(time_))
2321        return (elements_finished, next_input, next_state, emit_output,
2322                loop_state)
2323
2324      r = rnn.raw_rnn(cell, loop_fn)
2325      loop_state = r[-1]
2326      loop_state = loop_state.stack()
2327      self.assertAllEqual([1, 2, 2 + 2, 4 + 3, 7 + 4], loop_state)
2328
2329  @test_util.run_v1_only("b/124229375")
2330  def testEmitDifferentStructureThanCellOutput(self):
2331    with self.session(graph=ops.Graph()) as sess:
2332      max_time = 10
2333      batch_size = 16
2334      input_depth = 4
2335      num_units = 3
2336
2337      inputs = np.random.randn(max_time, batch_size, input_depth)
2338      inputs_ta = tensor_array_ops.TensorArray(
2339          dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
2340      inputs_ta = inputs_ta.unstack(inputs)
2341      # Verify that emit shapes may be unknown by feeding a placeholder that
2342      # determines an emit shape.
2343      unknown_dim = array_ops.placeholder(dtype=dtypes.int32)
2344
2345      cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
2346
2347      def loop_fn(time_, cell_output, cell_state, _):
2348        if cell_output is None:
2349          emit_output = (array_ops.zeros([2, 3], dtype=dtypes.int32),
2350                         array_ops.zeros([unknown_dim], dtype=dtypes.int64))
2351          next_state = cell.zero_state(batch_size, dtypes.float32)
2352        else:
2353          emit_output = (array_ops.ones([batch_size, 2, 3], dtype=dtypes.int32),
2354                         array_ops.ones(
2355                             [batch_size, unknown_dim], dtype=dtypes.int64))
2356          next_state = cell_state
2357        elements_finished = array_ops.tile([time_ >= max_time], [batch_size])
2358        finished = math_ops.reduce_all(elements_finished)
2359        # For the very final iteration, we must emit a dummy input
2360        next_input = control_flow_ops.cond(
2361            finished,
2362            lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32),
2363            lambda: inputs_ta.read(time_))
2364        return (elements_finished, next_input, next_state, emit_output, None)
2365
2366      r = rnn.raw_rnn(cell, loop_fn)
2367      output_ta = r[0]
2368      self.assertEqual(2, len(output_ta))
2369      self.assertEqual([dtypes.int32, dtypes.int64],
2370                       [ta.dtype for ta in output_ta])
2371      output = [ta.stack() for ta in output_ta]
2372      output_vals = sess.run(output, feed_dict={unknown_dim: 1})
2373      self.assertAllEqual(
2374          np.ones((max_time, batch_size, 2, 3), np.int32), output_vals[0])
2375      self.assertAllEqual(
2376          np.ones((max_time, batch_size, 1), np.int64), output_vals[1])
2377
2378  def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
2379    with self.session(graph=ops.Graph()):
2380      if use_outer_scope:
2381        with variable_scope.variable_scope(prefix) as scope:
2382          factory(scope)
2383      else:
2384        factory(prefix)
2385        variables_lib.global_variables_initializer()
2386
2387      # Check that all the variable names start
2388      # with the proper scope.
2389      all_vars = variables_lib.global_variables()
2390      prefix = prefix or "rnn"
2391      scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")]
2392      tf_logging.info("RNN with scope: %s (%s)" %
2393                      (prefix, "scope" if use_outer_scope else "str"))
2394      for v in scope_vars:
2395        tf_logging.info(v.name)
2396      self.assertEqual(len(scope_vars), len(all_vars))
2397
2398  @test_util.run_v1_only("b/124229375")
2399  def testRawRNNScope(self):
2400    max_time = 10
2401    batch_size = 16
2402    input_depth = 4
2403    num_units = 3
2404
2405    def factory(scope):
2406      inputs = array_ops.placeholder(
2407          shape=(max_time, batch_size, input_depth), dtype=dtypes.float32)
2408      sequence_length = array_ops.placeholder(
2409          shape=(batch_size,), dtype=dtypes.int32)
2410      inputs_ta = tensor_array_ops.TensorArray(
2411          dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
2412      inputs_ta = inputs_ta.unstack(inputs)
2413
2414      cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True)
2415
2416      def loop_fn(time_, cell_output, cell_state, unused_loop_state):
2417        emit_output = cell_output  # == None for time == 0
2418        if cell_output is None:  # time == 0
2419          next_state = cell.zero_state(batch_size, dtypes.float32)
2420        else:
2421          next_state = cell_state
2422
2423        elements_finished = (time_ >= sequence_length)
2424        finished = math_ops.reduce_all(elements_finished)
2425        # For the very final iteration, we must emit a dummy input
2426        next_input = control_flow_ops.cond(
2427            finished,
2428            lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32),
2429            lambda: inputs_ta.read(time_))
2430        return (elements_finished, next_input, next_state, emit_output, None)
2431
2432      return rnn.raw_rnn(cell, loop_fn, scope=scope)
2433
2434    self._testScope(factory, use_outer_scope=True)
2435    self._testScope(factory, use_outer_scope=False)
2436    self._testScope(factory, prefix=None, use_outer_scope=False)
2437
2438
2439class DeviceWrapperCell(rnn_cell.RNNCell):
2440  """Class to ensure cell calculation happens on a specific device."""
2441
2442  def __init__(self, cell, device):
2443    self._cell = cell
2444    self._device = device
2445
2446  @property
2447  def output_size(self):
2448    return self._cell.output_size
2449
2450  @property
2451  def state_size(self):
2452    return self._cell.state_size
2453
2454  def __call__(self, input_, state, scope=None):
2455    if self._device is not None:
2456      with ops.device(self._device):
2457        return self._cell(input_, state, scope=scope)
2458    else:
2459      return self._cell(input_, state, scope=scope)
2460
2461
2462class TensorArrayOnCorrectDeviceTest(test.TestCase):
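  """Checks that dynamic_rnn places TensorArray ops on the intended devices."""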
2463
2464  def _execute_rnn_on(self,
2465                      rnn_device=None,
2466                      cell_device=None,
2467                      input_device=None):
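    """Runs dynamic_rnn with the requested device placements.

    Executes the graph with full tracing enabled and returns the resulting
    RunMetadata so tests can inspect per-device node statistics.
    """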
2468    batch_size = 3
2469    time_steps = 7
2470    input_size = 5
2471    num_units = 10
2472
2473    cell = rnn_cell.LSTMCell(num_units, use_peepholes=True)
2474    gpu_cell = DeviceWrapperCell(cell, cell_device)
2475    inputs = np.random.randn(batch_size, time_steps, input_size).astype(
2476        np.float32)
2477    sequence_length = np.random.randint(0, time_steps, size=batch_size)
2478
2479    if input_device is not None:
2480      with ops.device(input_device):
2481        inputs = constant_op.constant(inputs)
2482
2483    if rnn_device is not None:
2484      with ops.device(rnn_device):
2485        outputs, _ = rnn.dynamic_rnn(
2486            gpu_cell,
2487            inputs,
2488            sequence_length=sequence_length,
2489            dtype=dtypes.float32)
2490    else:
2491      outputs, _ = rnn.dynamic_rnn(
2492          gpu_cell,
2493          inputs,
2494          sequence_length=sequence_length,
2495          dtype=dtypes.float32)
2496
2497    with self.session() as sess:
2498      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
2499      run_metadata = config_pb2.RunMetadata()
2500      variables_lib.global_variables_initializer().run()
2501      sess.run(outputs, options=opts, run_metadata=run_metadata)
2502
2503    return run_metadata
2504
2505  def _retrieve_cpu_gpu_stats(self, run_metadata):
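    """Extracts (cpu_stats, gpu_stats) node statistics from run_metadata."""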
2506    cpu_stats = None
2507    gpu_stats = None
2508    step_stats = run_metadata.step_stats
2509    for ds in step_stats.dev_stats:
2510      if "cpu:0" in ds.device[-5:].lower():
2511        cpu_stats = ds.node_stats
2512      if "gpu:0" == ds.device[-5:].lower():
2513        gpu_stats = ds.node_stats
2514    return cpu_stats, gpu_stats
2515
2516  @test_util.run_v1_only("b/124229375")
2517  def testRNNOnCPUCellOnGPU(self):
2518    if not test.is_gpu_available():
2519      return  # Test requires access to a GPU
2520
2521    gpu_dev = test.gpu_device_name()
2522    run_metadata = self._execute_rnn_on(
2523        rnn_device="/cpu:0", cell_device=gpu_dev)
2524    cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata)
2525
2526    def _assert_in(op_str, in_stats, out_stats):
2527      self.assertTrue(any(op_str in s.node_name for s in in_stats))
2528      self.assertFalse(any(op_str in s.node_name for s in out_stats))
2529
2530    # Writes happen at output of RNN cell
2531    _assert_in("TensorArrayWrite", gpu_stats, cpu_stats)
2532    # Gather happens on final TensorArray
2533    _assert_in("TensorArrayGather", gpu_stats, cpu_stats)
2534    # Reads happen at input to RNN cell
2535    _assert_in("TensorArrayRead", cpu_stats, gpu_stats)
2536    # Scatters happen to get initial input into TensorArray
2537    _assert_in("TensorArrayScatter", cpu_stats, gpu_stats)
2538
2539  @test_util.run_v1_only("b/124229375")
2540  def testRNNOnCPUCellOnCPU(self):
2541    if not test.is_gpu_available():
2542      return  # Test requires access to a GPU
2543
2544    gpu_dev = test.gpu_device_name()
2545    run_metadata = self._execute_rnn_on(
2546        rnn_device="/cpu:0", cell_device="/cpu:0", input_device=gpu_dev)
2547    cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata)
2548
2549    def _assert_in(op_str, in_stats, out_stats):
2550      self.assertTrue(any(op_str in s.node_name for s in in_stats))
2551      self.assertFalse(any(op_str in s.node_name for s in out_stats))
2552
2553    # All TensorArray operations happen on CPU
2554    _assert_in("TensorArray", cpu_stats, gpu_stats)
2555
2556  @test_util.run_v1_only("b/124229375")
2557  def testInputOnGPUCellNotDeclared(self):
2558    if not test.is_gpu_available():
2559      return  # Test requires access to a GPU
2560
2561    gpu_dev = test.gpu_device_name()
2562    run_metadata = self._execute_rnn_on(input_device=gpu_dev)
2563    cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata)
2564
2565    def _assert_in(op_str, in_stats, out_stats):
2566      self.assertTrue(any(op_str in s.node_name for s in in_stats))
2567      self.assertFalse(any(op_str in s.node_name for s in out_stats))
2568
2569    # Everything happens on GPU
2570    _assert_in("TensorArray", gpu_stats, cpu_stats)
2571
2572
2573class RNNCellTest(test.TestCase, parameterized.TestCase):
2574
2575  @test_util.run_v1_only("b/124229375")
2576  def testBasicRNNCell(self):
2577    with self.cached_session() as sess:
2578      with variable_scope.variable_scope(
2579          "root", initializer=init_ops.constant_initializer(0.5)):
2580        x = array_ops.zeros([1, 2])
2581        m = array_ops.zeros([1, 2])
2582        cell = rnn_cell_impl.BasicRNNCell(2)
2583        g, _ = cell(x, m)
2584        self.assertEqual([
2585            "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
2586            "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
2587        ], [v.name for v in cell.trainable_variables])
2588        self.assertFalse(cell.non_trainable_variables)
2589        sess.run([variables_lib.global_variables_initializer()])
2590        res = sess.run([g], {
2591            x: np.array([[1., 1.]]),
2592            m: np.array([[0.1, 0.1]])
2593        })
2594        self.assertEqual(res[0].shape, (1, 2))
2595
2596  @test_util.run_v1_only("b/124229375")
2597  def testBasicRNNCellNotTrainable(self):
2598    with self.cached_session() as sess:
2599
2600      def not_trainable_getter(getter, *args, **kwargs):
2601        kwargs["trainable"] = False
2602        return getter(*args, **kwargs)
2603
2604      with variable_scope.variable_scope(
2605          "root",
2606          initializer=init_ops.constant_initializer(0.5),
2607          custom_getter=not_trainable_getter):
2608        x = array_ops.zeros([1, 2])
2609        m = array_ops.zeros([1, 2])
2610        cell = rnn_cell_impl.BasicRNNCell(2)
2611        g, _ = cell(x, m)
2612        self.assertFalse(cell.trainable_variables)
2613        self.assertEqual([
2614            "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
2615            "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
2616        ], [v.name for v in cell.non_trainable_variables])
2617        sess.run([variables_lib.global_variables_initializer()])
2618        res = sess.run([g], {
2619            x: np.array([[1., 1.]]),
2620            m: np.array([[0.1, 0.1]])
2621        })
2622        self.assertEqual(res[0].shape, (1, 2))
2623
2624  @test_util.run_v1_only("b/124229375")
2625  def testGRUCell(self):
2626    with self.cached_session() as sess:
2627      with variable_scope.variable_scope(
2628          "root", initializer=init_ops.constant_initializer(0.5)):
2629        x = array_ops.zeros([1, 2])
2630        m = array_ops.zeros([1, 2])
2631        g, _ = rnn_cell_impl.GRUCell(2)(x, m)
2632        sess.run([variables_lib.global_variables_initializer()])
2633        res = sess.run([g], {
2634            x: np.array([[1., 1.]]),
2635            m: np.array([[0.1, 0.1]])
2636        })
2637        # Smoke test
2638        self.assertAllClose(res[0], [[0.175991, 0.175991]])
2639      with variable_scope.variable_scope(
2640          "other", initializer=init_ops.constant_initializer(0.5)):
2641        # Test GRUCell with input_size != num_units.
2642        x = array_ops.zeros([1, 3])
2643        m = array_ops.zeros([1, 2])
2644        g, _ = rnn_cell_impl.GRUCell(2)(x, m)
2645        sess.run([variables_lib.global_variables_initializer()])
2646        res = sess.run([g], {
2647            x: np.array([[1., 1., 1.]]),
2648            m: np.array([[0.1, 0.1]])
2649        })
2650        # Smoke test
2651        self.assertAllClose(res[0], [[0.156736, 0.156736]])
2652
2653  @test_util.run_v1_only("b/124229375")
2654  def testBasicLSTMCell(self):
2655    for dtype in [dtypes.float16, dtypes.float32]:
2656      np_dtype = dtype.as_numpy_dtype
2657      with self.session(graph=ops.Graph()) as sess:
2658        with variable_scope.variable_scope(
2659            "root", initializer=init_ops.constant_initializer(0.5)):
2660          x = array_ops.zeros([1, 2], dtype=dtype)
2661          m = array_ops.zeros([1, 8], dtype=dtype)
2662          cell = rnn_cell_impl.MultiRNNCell(
2663              [
2664                  rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
2665                  for _ in range(2)
2666              ],
2667              state_is_tuple=False)
2668          self.assertEqual(cell.dtype, None)
2669          self.assertIn("cell-0", cell._trackable_children())
2670          self.assertIn("cell-1", cell._trackable_children())
2671          cell.get_config()  # Should not throw an error
2672          g, out_m = cell(x, m)
2673          # Layer infers the input type.
2674          self.assertEqual(cell.dtype, dtype.name)
2675          expected_variable_names = [
2676              "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" %
2677              rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
2678              "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" %
2679              rnn_cell_impl._BIAS_VARIABLE_NAME,
2680              "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
2681              rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
2682              "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
2683              rnn_cell_impl._BIAS_VARIABLE_NAME
2684          ]
2685          self.assertEqual(expected_variable_names,
2686                           [v.name for v in cell.trainable_variables])
2687          self.assertFalse(cell.non_trainable_variables)
2688          sess.run([variables_lib.global_variables_initializer()])
2689          res = sess.run([g, out_m], {
2690              x: np.array([[1., 1.]]),
2691              m: 0.1 * np.ones([1, 8])
2692          })
2693          self.assertEqual(len(res), 2)
2694          variables = variables_lib.global_variables()
2695          self.assertEqual(expected_variable_names, [v.name for v in variables])
2696          # The numbers in the results were not hand-computed; this is just
2697          # a smoke test.
2698          self.assertAllClose(res[0], np.array(
2699              [[0.240, 0.240]], dtype=np_dtype), 1e-2)
2700          expected_mem = np.array(
2701              [[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]],
2702              dtype=np_dtype)
2703          self.assertAllClose(res[1], expected_mem, 1e-2)
2704        with variable_scope.variable_scope(
2705            "other", initializer=init_ops.constant_initializer(0.5)):
2706          # Test BasicLSTMCell with input_size != num_units.
2707          x = array_ops.zeros([1, 3], dtype=dtype)
2708          m = array_ops.zeros([1, 4], dtype=dtype)
2709          g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m)
2710          sess.run([variables_lib.global_variables_initializer()])
2711          res = sess.run(
2712              [g, out_m], {
2713                  x: np.array([[1., 1., 1.]], dtype=np_dtype),
2714                  m: 0.1 * np.ones([1, 4], dtype=np_dtype)
2715              })
2716          self.assertEqual(len(res), 2)
2717
2718  @test_util.run_v1_only("b/124229375")
2719  def testBasicLSTMCellDimension0Error(self):
2720    """Tests that dimension 0 of the x and m shapes must be equal."""
2721    with self.cached_session() as sess:
2722      with variable_scope.variable_scope(
2723          "root", initializer=init_ops.constant_initializer(0.5)):
2724        num_units = 2
2725        state_size = num_units * 2
2726        batch_size = 3
2727        input_size = 4
2728        x = array_ops.zeros([batch_size, input_size])
2729        m = array_ops.zeros([batch_size - 1, state_size])
2730        with self.assertRaises(ValueError):
2731          g, out_m = rnn_cell_impl.BasicLSTMCell(
2732              num_units, state_is_tuple=False)(x, m)
2733          sess.run([variables_lib.global_variables_initializer()])
2734          sess.run(
2735              [g, out_m], {
2736                  x: 1 * np.ones([batch_size, input_size]),
2737                  m: 0.1 * np.ones([batch_size - 1, state_size])
2738              })
2739
2740  def testBasicLSTMCellStateSizeError(self):
2741    """Tests that state_size must be num_units * 2."""
2742    with self.cached_session() as sess:
2743      with variable_scope.variable_scope(
2744          "root", initializer=init_ops.constant_initializer(0.5)):
2745        num_units = 2
2746        state_size = num_units * 3  # state_size must be num_units * 2
2747        batch_size = 3
2748        input_size = 4
2749        x = array_ops.zeros([batch_size, input_size])
2750        m = array_ops.zeros([batch_size, state_size])
2751        with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
2752          g, out_m = rnn_cell_impl.BasicLSTMCell(
2753              num_units, state_is_tuple=False)(x, m)
2754          sess.run([variables_lib.global_variables_initializer()])
2755          sess.run(
2756              [g, out_m], {
2757                  x: 1 * np.ones([batch_size, input_size]),
2758                  m: 0.1 * np.ones([batch_size, state_size])
2759              })
2760
2761  @test_util.run_v1_only("b/124229375")
2762  def testBasicLSTMCellStateTupleType(self):
2763    with self.cached_session():
2764      with variable_scope.variable_scope(
2765          "root", initializer=init_ops.constant_initializer(0.5)):
2766        x = array_ops.zeros([1, 2])
2767        m0 = (array_ops.zeros([1, 2]),) * 2
2768        m1 = (array_ops.zeros([1, 2]),) * 2
2769        cell = rnn_cell_impl.MultiRNNCell(
2770            [rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)],
2771            state_is_tuple=True)
2772        self.assertIsInstance(cell.state_size, tuple)
2773        self.assertIsInstance(
2774            cell.state_size[0], rnn_cell_impl.LSTMStateTuple)
2775        self.assertIsInstance(
2776            cell.state_size[1], rnn_cell_impl.LSTMStateTuple)
2777
2778        # Pass in regular tuples
2779        _, (out_m0, out_m1) = cell(x, (m0, m1))
2780        self.assertIsInstance(out_m0, rnn_cell_impl.LSTMStateTuple)
2781        self.assertIsInstance(out_m1, rnn_cell_impl.LSTMStateTuple)
2782
2783        # Pass in LSTMStateTuples
2784        variable_scope.get_variable_scope().reuse_variables()
2785        zero_state = cell.zero_state(1, dtypes.float32)
2786        self.assertIsInstance(zero_state, tuple)
2787        self.assertIsInstance(zero_state[0], rnn_cell_impl.LSTMStateTuple)
2788        self.assertIsInstance(zero_state[1], rnn_cell_impl.LSTMStateTuple)
2789        _, (out_m0, out_m1) = cell(x, zero_state)
2790        self.assertIsInstance(out_m0, rnn_cell_impl.LSTMStateTuple)
2791        self.assertIsInstance(out_m1, rnn_cell_impl.LSTMStateTuple)
2792
2793  @test_util.run_v1_only("b/124229375")
2794  def testBasicLSTMCellWithStateTuple(self):
2795    with self.cached_session() as sess:
2796      with variable_scope.variable_scope(
2797          "root", initializer=init_ops.constant_initializer(0.5)):
2798        x = array_ops.zeros([1, 2])
2799        m0 = array_ops.zeros([1, 4])
2800        m1 = array_ops.zeros([1, 4])
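        # Each wrapped cell uses state_is_tuple=False, so the state is the
        # pair (m0, m1), each a width-4 [c, h] vector that matches one half
        # of the flat state in testBasicLSTMCell.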
2801        cell = rnn_cell_impl.MultiRNNCell(
2802            [
2803                rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
2804                for _ in range(2)
2805            ],
2806            state_is_tuple=True)
2807        g, (out_m0, out_m1) = cell(x, (m0, m1))
2808        sess.run([variables_lib.global_variables_initializer()])
2809        res = sess.run(
2810            [g, out_m0, out_m1], {
2811                x: np.array([[1., 1.]]),
2812                m0: 0.1 * np.ones([1, 4]),
2813                m1: 0.1 * np.ones([1, 4])
2814            })
2815        self.assertEqual(len(res), 3)
2816        # The numbers in the results were not hand-calculated; this is just
2817        # a smoke test.  Note, however, that these values should match those
2818        # of the original version with state_is_tuple=False.
2819        self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
2820        expected_mem0 = np.array(
2821            [[0.68967271, 0.68967271, 0.44848421, 0.44848421]])
2822        expected_mem1 = np.array(
2823            [[0.39897051, 0.39897051, 0.24024698, 0.24024698]])
2824        self.assertAllClose(res[1], expected_mem0)
2825        self.assertAllClose(res[2], expected_mem1)
2826
2827  @test_util.run_v1_only("b/124229375")
2828  def testLSTMCell(self):
2829    with self.cached_session() as sess:
2830      num_units = 8
2831      num_proj = 6
2832      state_size = num_units + num_proj
2833      batch_size = 3
2834      input_size = 2
2835      with variable_scope.variable_scope(
2836          "root", initializer=init_ops.constant_initializer(0.5)):
2837        x = array_ops.zeros([batch_size, input_size])
2838        m = array_ops.zeros([batch_size, state_size])
2839        cell = rnn_cell_impl.LSTMCell(
2840            num_units=num_units,
2841            num_proj=num_proj,
2842            forget_bias=1.0,
2843            state_is_tuple=False)
2844        output, state = cell(x, m)
2845        sess.run([variables_lib.global_variables_initializer()])
2846        res = sess.run(
2847            [output, state], {
2848                x: np.array([[1., 1.], [2., 2.], [3., 3.]]),
2849                m: 0.1 * np.ones((batch_size, state_size))
2850            })
2851        self.assertEqual(len(res), 2)
2852        # The numbers in the results were not hand-calculated; this is
2853        # mostly just a smoke test.
2854        self.assertEqual(res[0].shape, (batch_size, num_proj))
2855        self.assertEqual(res[1].shape, (batch_size, state_size))
2856        # Different inputs, so different outputs and states.
2857        for i in range(1, batch_size):
2858          self.assertGreater(
2859              float(np.linalg.norm(res[0][0, :] - res[0][i, :])), 1e-6)
2860          self.assertGreater(
2861              float(np.linalg.norm(res[1][0, :] - res[1][i, :])), 1e-6)
2862
2863  @test_util.run_v1_only("b/124229375")
2864  def testLSTMCellVariables(self):
2865    with self.cached_session():
2866      num_units = 8
2867      num_proj = 6
2868      state_size = num_units + num_proj
2869      batch_size = 3
2870      input_size = 2
2871      with variable_scope.variable_scope(
2872          "root", initializer=init_ops.constant_initializer(0.5)):
2873        x = array_ops.zeros([batch_size, input_size])
2874        m = array_ops.zeros([batch_size, state_size])
2875        cell = rnn_cell_impl.LSTMCell(
2876            num_units=num_units,
2877            num_proj=num_proj,
2878            forget_bias=1.0,
2879            state_is_tuple=False)
2880        cell(x, m)  # Execute to create variables
2881      variables = variables_lib.global_variables()
2882      self.assertEqual(variables[0].op.name, "root/lstm_cell/kernel")
2883      self.assertEqual(variables[1].op.name, "root/lstm_cell/bias")
2884      self.assertEqual(variables[2].op.name, "root/lstm_cell/projection/kernel")
2885
2886  @test_util.run_in_graph_and_eager_modes
2887  def testWrapperCheckpointing(self):
2888    for wrapper_type in [
2889        rnn_cell_impl.DropoutWrapper,
2890        rnn_cell_impl.ResidualWrapper,
2891        lambda cell: rnn_cell_impl.MultiRNNCell([cell])]:
2892      cell = rnn_cell_impl.BasicRNNCell(1)
2893      wrapper = wrapper_type(cell)
2894      wrapper(array_ops.ones([1, 1]),
2895              state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
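      # Calling the wrapper once builds the underlying cell's variables so
      # they can be saved and restored through the wrapper checkpoint below.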
2896      self.evaluate([v.initializer for v in cell.variables])
2897      checkpoint = trackable_utils.Checkpoint(wrapper=wrapper)
2898      prefix = os.path.join(self.get_temp_dir(), "ckpt")
2899      self.evaluate(cell._bias.assign([40.]))
2900      save_path = checkpoint.save(prefix)
2901      self.evaluate(cell._bias.assign([0.]))
2902      checkpoint.restore(save_path).assert_consumed().run_restore_ops()
2903      self.assertAllEqual([40.], self.evaluate(cell._bias))
2904
2905  @test_util.run_in_graph_and_eager_modes
2906  def testResidualWrapper(self):
2907    wrapper_type = rnn_cell_impl.ResidualWrapper
2908    x = ops.convert_to_tensor(np.array([[1., 1., 1.]]))
2909    m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]]))
2910    base_cell = rnn_cell_impl.GRUCell(
2911        3, kernel_initializer=init_ops.constant_initializer(0.5),
2912        bias_initializer=init_ops.constant_initializer(0.5))
2913    g, m_new = base_cell(x, m)
2914    wrapper_object = wrapper_type(base_cell)
2915    wrapper_object.get_config()  # Should not throw an error
2916
2917    self.assertIn("cell", wrapper_object._trackable_children())
2918    self.assertIs(wrapper_object._trackable_children()["cell"], base_cell)
2919
2920    g_res, m_new_res = wrapper_object(x, m)
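    # With the default residual_fn, ResidualWrapper is expected to compute
    #   output, new_state = cell(inputs, state)
    #   return inputs + output, new_state
    # so the assertions below compare against the bare GRUCell's results.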
2921    self.evaluate([variables_lib.global_variables_initializer()])
2922    res = self.evaluate([g, g_res, m_new, m_new_res])
2923    # Residual connections
2924    self.assertAllClose(res[1], res[0] + [1., 1., 1.])
2925    # States are left untouched
2926    self.assertAllClose(res[2], res[3])
2927
2928  @test_util.run_in_graph_and_eager_modes
2929  def testResidualWrapperWithSlice(self):
2930    wrapper_type = rnn_cell_impl.ResidualWrapper
2931    x = ops.convert_to_tensor(np.array([[1., 1., 1., 1., 1.]]))
2932    m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]]))
2933    base_cell = rnn_cell_impl.GRUCell(
2934        3, kernel_initializer=init_ops.constant_initializer(0.5),
2935        bias_initializer=init_ops.constant_initializer(0.5))
2936    g, m_new = base_cell(x, m)
2937
2938    def residual_with_slice_fn(inp, out):
2939      inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3])
2940      return inp_sliced + out
2941
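    # residual_with_slice_fn trims the 5-wide input down to its first 3
    # columns so the residual addition matches the 3-wide GRU output; the
    # expected output below is therefore cell_output + [1., 1., 1.].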
2942    g_res, m_new_res = wrapper_type(
2943        base_cell, residual_with_slice_fn)(x, m)
2944    self.evaluate([variables_lib.global_variables_initializer()])
2945    res_g, res_g_res, res_m_new, res_m_new_res = self.evaluate(
2946        [g, g_res, m_new, m_new_res])
2947    # Residual connections
2948    self.assertAllClose(res_g_res, res_g + [1., 1., 1.])
2949    # States are left untouched
2950    self.assertAllClose(res_m_new, res_m_new_res)
2951
2952  def testDeviceWrapper(self):
2953    wrapper_type = rnn_cell_impl.DeviceWrapper
2954    x = array_ops.zeros([1, 3])
2955    m = array_ops.zeros([1, 3])
2956    cell = rnn_cell_impl.GRUCell(3)
2957    wrapped_cell = wrapper_type(cell, "/cpu:0")
2958    wrapped_cell.get_config()  # Should not throw an error
2959    self.assertEqual(wrapped_cell._trackable_children()["cell"], cell)
2960
2961    outputs, _ = wrapped_cell(x, m)
2962    self.assertIn("cpu:0", outputs.device.lower())
2963
2964  def _retrieve_cpu_gpu_stats(self, run_metadata):
2965    cpu_stats = None
2966    gpu_stats = None
2967    step_stats = run_metadata.step_stats
2968    for ds in step_stats.dev_stats:
2969      if "cpu:0" in ds.device[-5:].lower():
2970        cpu_stats = ds.node_stats
2971      if "gpu:0" == ds.device[-5:].lower():
2972        gpu_stats = ds.node_stats
2973    return cpu_stats, gpu_stats
2974
2975  @test_util.run_v1_only("b/124229375")
2976  def testDeviceWrapperDynamicExecutionNodesAreAllProperlyLocated(self):
2977    if not test.is_gpu_available():
2978      # Can't perform this test w/o a GPU
2979      return
2980
2981    gpu_dev = test.gpu_device_name()
2982    with self.session() as sess:
2983      with variable_scope.variable_scope(
2984          "root", initializer=init_ops.constant_initializer(0.5)):
2985        x = array_ops.zeros([1, 1, 3])
2986        cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), gpu_dev)
2987        with ops.device("/cpu:0"):
2988          outputs, _ = rnn.dynamic_rnn(
2989              cell=cell, inputs=x, dtype=dtypes.float32)
2990        run_metadata = config_pb2.RunMetadata()
2991        opts = config_pb2.RunOptions(
2992            trace_level=config_pb2.RunOptions.FULL_TRACE)
2993
2994        sess.run([variables_lib.global_variables_initializer()])
2995        _ = sess.run(outputs, options=opts, run_metadata=run_metadata)
2996
2997      cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata)
2998      self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name])
2999      self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name])
3000
3001  @test_util.run_v1_only("b/124229375")
3002  def testMultiRNNCell(self):
3003    with self.cached_session() as sess:
3004      with variable_scope.variable_scope(
3005          "root", initializer=init_ops.constant_initializer(0.5)):
3006        x = array_ops.zeros([1, 2])
3007        m = array_ops.zeros([1, 4])
3008        multi_rnn_cell = rnn_cell_impl.MultiRNNCell(
3009            [rnn_cell_impl.GRUCell(2) for _ in range(2)],
3010            state_is_tuple=False)
3011        _, ml = multi_rnn_cell(x, m)
3012        sess.run([variables_lib.global_variables_initializer()])
3013        res = sess.run(ml, {
3014            x: np.array([[1., 1.]]),
3015            m: np.array([[0.1, 0.1, 0.1, 0.1]])
3016        })
3017        # These numbers were not hand-calculated; this is just a smoke test.
3018        self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]])
3019        self.assertEqual(len(multi_rnn_cell.weights), 2 * 4)
3020        self.assertTrue(
3021            all(x.dtype == dtypes.float32 for x in multi_rnn_cell.weights))
3022
3023  @test_util.run_v1_only("b/124229375")
3024  def testMultiRNNCellWithStateTuple(self):
3025    with self.cached_session() as sess:
3026      with variable_scope.variable_scope(
3027          "root", initializer=init_ops.constant_initializer(0.5)):
3028        x = array_ops.zeros([1, 2])
3029        m_bad = array_ops.zeros([1, 4])
3030        m_good = (array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))
3031
3032        # Test that a non-tuple state raises an error.
3033        with self.assertRaisesRegex(ValueError, "Expected state .* a tuple"):
3034          rnn_cell_impl.MultiRNNCell(
3035              [rnn_cell_impl.GRUCell(2) for _ in range(2)],
3036              state_is_tuple=True)(x, m_bad)
3037
3038        _, ml = rnn_cell_impl.MultiRNNCell(
3039            [rnn_cell_impl.GRUCell(2) for _ in range(2)],
3040            state_is_tuple=True)(x, m_good)
3041
3042        sess.run([variables_lib.global_variables_initializer()])
3043        res = sess.run(
3044            ml, {
3045                x: np.array([[1., 1.]]),
3046                m_good[0]: np.array([[0.1, 0.1]]),
3047                m_good[1]: np.array([[0.1, 0.1]])
3048            })
3049
3050        # The numbers in the results were not hand-calculated; this is just
3051        # a smoke test.  However, these numbers should match those of
3052        # the test testMultiRNNCell.
3053        self.assertAllClose(res[0], [[0.175991, 0.175991]])
3054        self.assertAllClose(res[1], [[0.13248, 0.13248]])
3055
3056  def testDeviceWrapperSerialization(self):
3057    wrapper_cls = rnn_cell_impl.DeviceWrapper
3058    cell = rnn_cell_impl.LSTMCell(10)
3059    wrapper = wrapper_cls(cell, "/cpu:0")
3060    config = wrapper.get_config()
3061
3062    # Replace the cell in the config with a real cell instance to work
3063    # around the reverse Keras dependency issue.
3064    config_copy = config.copy()
3065    config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config(
3066        config_copy["cell"]["config"])
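    # The wrapper serializes its cell as a nested dict; only the inner
    # "config" entry (roughly, the cell's constructor arguments) is passed
    # back to LSTMCell.from_config here.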
3067    reconstructed_wrapper = wrapper_cls.from_config(config_copy)
3068    self.assertDictEqual(config, reconstructed_wrapper.get_config())
3069    self.assertIsInstance(reconstructed_wrapper, wrapper_cls)
3070
3071  def testResidualWrapperSerialization(self):
3072    wrapper_cls = rnn_cell_impl.ResidualWrapper
3073    cell = rnn_cell_impl.LSTMCell(10)
3074    wrapper = wrapper_cls(cell)
3075    config = wrapper.get_config()
3076
3077    # Replace the cell in the config with a real cell instance to work
3078    # around the reverse Keras dependency issue.
3079    config_copy = config.copy()
3080    config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config(
3081        config_copy["cell"]["config"])
3082    reconstructed_wrapper = wrapper_cls.from_config(config_copy)
3083    self.assertDictEqual(config, reconstructed_wrapper.get_config())
3084    self.assertIsInstance(reconstructed_wrapper, wrapper_cls)
3085
3086    wrapper = wrapper_cls(cell, residual_fn=lambda i, o: i + i + o)
3087    config = wrapper.get_config()
3088
3089    config_copy = config.copy()
3090    config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config(
3091        config_copy["cell"]["config"])
3092    reconstructed_wrapper = wrapper_cls.from_config(config_copy)
3093    # Assert the reconstructed function will perform the math correctly.
3094    self.assertEqual(reconstructed_wrapper._residual_fn(1, 2), 4)
3095
3096    def residual_fn(inputs, outputs):
3097      return inputs * 3 + outputs
3098
3099    wrapper = wrapper_cls(cell, residual_fn=residual_fn)
3100    config = wrapper.get_config()
3101
3102    config_copy = config.copy()
3103    config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config(
3104        config_copy["cell"]["config"])
3105    reconstructed_wrapper = wrapper_cls.from_config(config_copy)
3106    # Assert the reconstructed function will perform the math correctly.
3107    self.assertEqual(reconstructed_wrapper._residual_fn(1, 2), 5)
3108
3109  def testDropoutWrapperSerialization(self):
3110    wrapper_cls = rnn_cell_impl.DropoutWrapper
3111    cell = rnn_cell_impl.LSTMCell(10)
3112    wrapper = wrapper_cls(cell)
3113    config = wrapper.get_config()
3114
3115    config_copy = config.copy()
3116    config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config(
3117        config_copy["cell"]["config"])
3118    reconstructed_wrapper = wrapper_cls.from_config(config_copy)
3119    self.assertDictEqual(config, reconstructed_wrapper.get_config())
3120    self.assertIsInstance(reconstructed_wrapper, wrapper_cls)
3121
3122    wrapper = wrapper_cls(cell, dropout_state_filter_visitor=lambda s: True)
3123    config = wrapper.get_config()
3124
3125    config_copy = config.copy()
3126    config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config(
3127        config_copy["cell"]["config"])
3128    reconstructed_wrapper = wrapper_cls.from_config(config_copy)
3129    self.assertTrue(reconstructed_wrapper._dropout_state_filter(None))
3130
3131    def dropout_state_filter_visitor(unused_state):
3132      return False
3133
3134    wrapper = wrapper_cls(
3135        cell, dropout_state_filter_visitor=dropout_state_filter_visitor)
3136    config = wrapper.get_config()
3137
3138    config_copy = config.copy()
3139    config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config(
3140        config_copy["cell"]["config"])
3141    reconstructed_wrapper = wrapper_cls.from_config(config_copy)
3142    self.assertFalse(reconstructed_wrapper._dropout_state_filter(None))
3143
3144  def testSavedModel(self):
3145    if test_util.is_gpu_available():
3146      self.skipTest("b/175887901")
3147
3148    with self.cached_session():
3149      root = autotrackable.AutoTrackable()
3150      root.cell = rnn_cell_impl.LSTMCell(8)
3151      @def_function.function(input_signature=[tensor_spec.TensorSpec([3, 8])])
3152      def call(x):
3153        state = root.cell.zero_state(3, dtype=x.dtype)
3154        y, _ = root.cell(x, state)
3155        return y
3156      root.call = call
3157      expected = root.call(array_ops.zeros((3, 8)))
3158      self.evaluate(variables_lib.global_variables_initializer())
3159
3160      save_dir = os.path.join(self.get_temp_dir(), "saved_model")
3161      save.save(root, save_dir)
3162      loaded = load.load(save_dir)
3163      self.evaluate(variables_lib.global_variables_initializer())
3164      self.assertAllClose(
3165          expected, loaded.call(array_ops.zeros((3, 8))))
3166
3167
3168@test_util.run_all_in_graph_and_eager_modes
3169@test_util.run_all_without_tensor_float_32(
3170    "Uses an LSTMCell, which calls matmul")
3171class DropoutWrapperTest(test.TestCase, parameterized.TestCase):
3172
3173  def _testDropoutWrapper(self,
3174                          batch_size=None,
3175                          time_steps=None,
3176                          parallel_iterations=None,
3177                          wrapper_type=None,
3178                          scope="root",
3179                          **kwargs):
3180    if batch_size is None and time_steps is None:
3181      # 2 time steps, batch size 1, depth 3
3182      batch_size = 1
3183      time_steps = 2
3184      x = constant_op.constant(
3185          [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
3186      m = rnn_cell_impl.LSTMStateTuple(
3187          *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32)] * 2)
3188    else:
3189      x = constant_op.constant(
3190          np.random.randn(time_steps, batch_size, 3).astype(np.float32))
3191      m = rnn_cell_impl.LSTMStateTuple(*[
3192          constant_op.constant(
3193              [[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)] * 2)
3194    outputs, final_state = rnn.dynamic_rnn(
3195        cell=wrapper_type(
3196            rnn_cell_impl.LSTMCell(
3197                3, initializer=init_ops.constant_initializer(0.5)),
3198            dtype=x.dtype, **kwargs),
3199        time_major=True,
3200        parallel_iterations=parallel_iterations,
3201        inputs=x,
3202        initial_state=m,
3203        scope=scope)
3204    self.evaluate([variables_lib.global_variables_initializer()])
3205    res = self.evaluate([outputs, final_state])
3206    self.assertEqual(res[0].shape, (time_steps, batch_size, 3))
3207    self.assertEqual(res[1].c.shape, (batch_size, 3))
3208    self.assertEqual(res[1].h.shape, (batch_size, 3))
3209    return res
3210
3211  def testDropoutWrapperProperties(self):
3212    wrapper_type = rnn_cell_impl.DropoutWrapper
3213    cell = rnn_cell_impl.BasicRNNCell(10)
3214    wrapper = wrapper_type(cell)
3215    # GitHub issue 15810.
3216    self.assertEqual(wrapper.wrapped_cell, cell)
3217    self.assertEqual(wrapper.state_size, 10)
3218    self.assertEqual(wrapper.output_size, 10)
3219
3220  def testDropoutWrapperZeroState(self):
3221    wrapper_type = rnn_cell_impl.DropoutWrapper
3222
3223    class _Cell(rnn_cell_impl.BasicRNNCell):
3224
3225      def zero_state(self, batch_size=None, dtype=None):
3226        return "wrapped_cell_zero_state"
3227    wrapper = wrapper_type(_Cell(10))
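    # DropoutWrapper delegates zero_state to the wrapped cell, so the
    # sentinel value returned by _Cell.zero_state comes back unchanged.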
3228    self.assertEqual(wrapper.zero_state(10, dtypes.float32),
3229                     "wrapped_cell_zero_state")
3230
3231  def testDropoutWrapperKeepAllConstantInput(self):
3232    wrapper_type = rnn_cell_impl.DropoutWrapper
3233    keep = array_ops.ones([])
3234    res = self._testDropoutWrapper(
3235        input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep,
3236        wrapper_type=wrapper_type)
3237    true_full_output = np.array(
3238        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
3239        dtype=np.float32)
3240    true_full_final_c = np.array(
3241        [[1.949385, 1.949385, 1.949385]], dtype=np.float32)
3242    self.assertAllClose(true_full_output, res[0])
3243    self.assertAllClose(true_full_output[1], res[1].h)
3244    self.assertAllClose(true_full_final_c, res[1].c)
3245
3246  def testDropoutWrapperKeepAll(self):
3247    wrapper_type = rnn_cell_impl.DropoutWrapper
3248    keep = variable_scope.get_variable("all", initializer=1.0)
3249    res = self._testDropoutWrapper(
3250        input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep,
3251        wrapper_type=wrapper_type)
3252    true_full_output = np.array(
3253        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
3254        dtype=np.float32)
3255    true_full_final_c = np.array(
3256        [[1.949385, 1.949385, 1.949385]], dtype=np.float32)
3257    self.assertAllClose(true_full_output, res[0])
3258    self.assertAllClose(true_full_output[1], res[1].h)
3259    self.assertAllClose(true_full_final_c, res[1].c)
3260
3261  def testDropoutWrapperWithSeed(self):
3262    wrapper_type = rnn_cell_impl.DropoutWrapper
3263    keep_some = 0.5
3264    random_seed.set_random_seed(2)
3265    # Use parallel_iterations=1 in both calls to _testDropoutWrapper to
3266    # ensure the (per-time-step) dropout is consistent across both calls.
3267    # Otherwise the seed may not end up being munged consistently
3268    # across both graphs.
3269    res_standard_1 = self._testDropoutWrapper(
3270        input_keep_prob=keep_some,
3271        output_keep_prob=keep_some,
3272        state_keep_prob=keep_some,
3273        seed=10,
3274        parallel_iterations=1,
3275        wrapper_type=wrapper_type,
3276        scope="root_1")
3277    random_seed.set_random_seed(2)
3278    res_standard_2 = self._testDropoutWrapper(
3279        input_keep_prob=keep_some,
3280        output_keep_prob=keep_some,
3281        state_keep_prob=keep_some,
3282        seed=10,
3283        parallel_iterations=1,
3284        wrapper_type=wrapper_type,
3285        scope="root_2")
3286    self.assertAllClose(res_standard_1[0], res_standard_2[0])
3287    self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c)
3288    self.assertAllClose(res_standard_1[1].h, res_standard_2[1].h)
3289
3290  def testDropoutWrapperKeepNoOutput(self):
3291    wrapper_type = rnn_cell_impl.DropoutWrapper
3292    keep_all = variable_scope.get_variable("all", initializer=1.0)
3293    keep_none = variable_scope.get_variable("none", initializer=1e-6)
3294    res = self._testDropoutWrapper(
3295        input_keep_prob=keep_all,
3296        output_keep_prob=keep_none,
3297        state_keep_prob=keep_all,
3298        wrapper_type=wrapper_type)
3299    true_full_output = np.array(
3300        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
3301        dtype=np.float32)
3302    true_full_final_c = np.array(
3303        [[1.949385, 1.949385, 1.949385]], dtype=np.float32)
3304    self.assertAllClose(np.zeros(res[0].shape), res[0])
3305    self.assertAllClose(true_full_output[1], res[1].h)
3306    self.assertAllClose(true_full_final_c, res[1].c)
3307
3308  def testDropoutWrapperKeepNoStateExceptLSTMCellMemory(self):
3309    wrapper_type = rnn_cell_impl.DropoutWrapper
3310    keep_all = variable_scope.get_variable("all", initializer=1.0)
3311    keep_none = variable_scope.get_variable("none", initializer=1e-6)
3312    # Even though we drop out the state, by default DropoutWrapper never
3313    # drops out the memory ("c") term of an LSTMStateTuple.
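    # In other words, with the default dropout_state_filter_visitor only the
    # "h" part of the LSTMStateTuple gets state dropout; "c" passes through
    # unchanged, which the assertions below rely on.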
3314    res = self._testDropoutWrapper(
3315        input_keep_prob=keep_all,
3316        output_keep_prob=keep_all,
3317        state_keep_prob=keep_none,
3318        wrapper_type=wrapper_type)
3319    true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32)
3320    true_full_output = np.array(
3321        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
3322        dtype=np.float32)
3323    self.assertAllClose(true_full_output[0], res[0][0])
3324    # Second output is modified by zero input state
3325    self.assertGreater(np.linalg.norm(true_full_output[1] - res[0][1]), 1e-4)
3326    # h state has been set to zero
3327    self.assertAllClose(np.zeros(res[1].h.shape), res[1].h)
3328    # c state of an LSTMStateTuple is NEVER modified.
3329    self.assertAllClose(true_c_state, res[1].c)
3330
3331  def testDropoutWrapperKeepNoInput(self):
3332    wrapper_type = rnn_cell_impl.DropoutWrapper
3333    keep_all = variable_scope.get_variable("all", initializer=1.0)
3334    keep_none = variable_scope.get_variable("none", initializer=1e-6)
3335    true_full_output = np.array(
3336        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
3337        dtype=np.float32)
3338    true_full_final_c = np.array(
3339        [[1.949385, 1.949385, 1.949385]], dtype=np.float32)
3340    # All outputs are different because inputs are zeroed out
3341    res = self._testDropoutWrapper(
3342        input_keep_prob=keep_none,
3343        output_keep_prob=keep_all,
3344        state_keep_prob=keep_all,
3345        wrapper_type=wrapper_type)
3346    self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4)
3347    self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4)
3348    self.assertGreater(np.linalg.norm(res[1].c - true_full_final_c), 1e-4)
3349
3350  def testDropoutWrapperRecurrentOutput(self):
3351    wrapper_type = rnn_cell_impl.DropoutWrapper
3352    keep_some = 0.8
3353    keep_all = variable_scope.get_variable("all", initializer=1.0)
3354    res = self._testDropoutWrapper(
3355        input_keep_prob=keep_all,
3356        output_keep_prob=keep_some,
3357        state_keep_prob=keep_all,
3358        variational_recurrent=True,
3359        wrapper_type=wrapper_type,
3360        input_size=3,
3361        batch_size=5,
3362        time_steps=7)
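    # With variational_recurrent=True the wrapper is expected to sample one
    # dropout mask per sequence and reuse it at every time step, so the
    # zero pattern of the outputs should be identical across time.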
3363    # Ensure the same dropout pattern for all time steps
3364    output_mask = np.abs(res[0]) > 1e-6
3365    for m in output_mask[1:]:
3366      self.assertAllClose(output_mask[0], m)
3367
3368  def testDropoutWrapperRecurrentStateInputAndOutput(self):
3369    wrapper_type = rnn_cell_impl.DropoutWrapper
3370    keep_some = 0.9
3371    res = self._testDropoutWrapper(
3372        input_keep_prob=keep_some,
3373        output_keep_prob=keep_some,
3374        state_keep_prob=keep_some,
3375        variational_recurrent=True,
3376        wrapper_type=wrapper_type,
3377        input_size=3,
3378        batch_size=5,
3379        time_steps=7)
3380
3381    # Smoke test for the state/input masks.
3382    output_mask = np.abs(res[0]) > 1e-6
3383    for time_step in output_mask:
3384      # Ensure the same dropout output pattern for all time steps
3385      self.assertAllClose(output_mask[0], time_step)
3386      for batch_entry in time_step:
3387        # Assert all batch entries get the same mask
3388        self.assertAllClose(batch_entry, time_step[0])
3389
3390    # For state, ensure all batch entries have the same mask
3391    state_c_mask = np.abs(res[1].c) > 1e-6
3392    state_h_mask = np.abs(res[1].h) > 1e-6
3393    for batch_entry in state_c_mask:
3394      self.assertAllClose(batch_entry, state_c_mask[0])
3395    for batch_entry in state_h_mask:
3396      self.assertAllClose(batch_entry, state_h_mask[0])
3397
3398  def testDropoutWrapperRecurrentStateInputAndOutputWithSeed(self):
3399    wrapper_type = rnn_cell_impl.DropoutWrapper
3400    keep_some = 0.9
3401    random_seed.set_random_seed(2347)
3402    np.random.seed(23487)
3403    res0 = self._testDropoutWrapper(
3404        input_keep_prob=keep_some,
3405        output_keep_prob=keep_some,
3406        state_keep_prob=keep_some,
3407        variational_recurrent=True,
3408        wrapper_type=wrapper_type,
3409        input_size=3,
3410        batch_size=5,
3411        time_steps=7,
3412        seed=-234987,
3413        scope="root_0")
3414    random_seed.set_random_seed(2347)
3415    np.random.seed(23487)
3416    res1 = self._testDropoutWrapper(
3417        input_keep_prob=keep_some,
3418        output_keep_prob=keep_some,
3419        state_keep_prob=keep_some,
3420        variational_recurrent=True,
3421        wrapper_type=wrapper_type,
3422        input_size=3,
3423        batch_size=5,
3424        time_steps=7,
3425        seed=-234987,
3426        scope="root_1")
3427
3428    output_mask = np.abs(res0[0]) > 1e-6
3429    for time_step in output_mask:
3430      # Ensure the same dropout output pattern for all time steps
3431      self.assertAllClose(output_mask[0], time_step)
3432      for batch_entry in time_step:
3433        # Assert all batch entries get the same mask
3434        self.assertAllClose(batch_entry, time_step[0])
3435
3436    # For state, ensure all batch entries have the same mask
3437    state_c_mask = np.abs(res0[1].c) > 1e-6
3438    state_h_mask = np.abs(res0[1].h) > 1e-6
3439    for batch_entry in state_c_mask:
3440      self.assertAllClose(batch_entry, state_c_mask[0])
3441    for batch_entry in state_h_mask:
3442      self.assertAllClose(batch_entry, state_h_mask[0])
3443
3444    # Ensure seeded calculation is identical.
3445    self.assertAllClose(res0[0], res1[0])
3446    self.assertAllClose(res0[1].c, res1[1].c)
3447    self.assertAllClose(res0[1].h, res1[1].h)
3448
3449
3450if __name__ == "__main__":
3451  test.main()
3452