# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rnn module."""

import os
import time
import timeit

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_lib
import tensorflow.python.ops.data_flow_grad  # pylint: disable=unused-import
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
import tensorflow.python.ops.sparse_grad  # pylint: disable=unused-import
import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import saver


class Plus1RNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell generating (output, new_state) = (input + 1, state + 1)."""

  @property
  def output_size(self):
    return 5

  @property
  def state_size(self):
    return 5

  def call(self, input_, state, scope=None):
    return (input_ + 1, state + 1)


class ScalarStateRNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell generating (output, new_state) = (input, state + 1)."""

  @property
  def output_size(self):
    return 1

  @property
  def state_size(self):
    return tensor_shape.TensorShape([])

  def zero_state(self, batch_size, dtype):
    return array_ops.zeros([], dtype=dtypes.int32)

  def call(self, input_, state, scope=None):
    return (input_, state + 1)


class UnbalancedOutputRNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell generating ((input, concat(input, input)), state + 1)."""

  @property
  def output_size(self):
    return tensor_shape.TensorShape(1), tensor_shape.TensorShape((2))

  @property
  def state_size(self):
    return tensor_shape.TensorShape([])

  def zero_state(self, batch_size, dtype):
    return array_ops.zeros([], dtype=dtypes.int32)

  def call(self, input_, state, scope=None):
    concatenated = array_ops.concat((input_, input_), axis=-1)
    return (input_, concatenated), state + 1


class TensorArrayStateRNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell that keeps part of its state in a TensorArray."""

  @property
  def output_size(self):
    return 1

  @property
  def state_size(self):
    return (tensor_shape.TensorShape([]), ())

  def zero_state(self, batch_size, dtype):
    return (array_ops.zeros([], dtype=dtypes.int32),
            tensor_array_ops.TensorArray(
                dtype=dtype, size=0, dynamic_size=True))

  def call(self, input_, state, scope=None):
    new_array = state[1].write(state[0], input_)
    return (input_, (state[0] + 1, new_array))


class RNNTest(test.TestCase):

  def setUp(self):
    self._seed = 23489
    np.random.seed(self._seed)

  @test_util.run_in_graph_and_eager_modes
  def testInvalidSequenceLengthShape(self):
    cell = Plus1RNNCell()
    if context.executing_eagerly():
      inputs = [constant_op.constant(np.ones((3, 4)))]
    else:
      inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
    with self.assertRaisesRegex(ValueError, "must be a vector"):
      rnn.dynamic_rnn(
          cell,
          array_ops.stack(inputs),
          dtype=dtypes.float32,
          sequence_length=[[4]])

  @test_util.run_in_graph_and_eager_modes
  def testInvalidDtype(self):
    if context.executing_eagerly():
      inputs = np.zeros((3, 4, 5), dtype=np.int32)
    else:
      inputs = array_ops.placeholder(dtypes.int32, shape=(3, 4, 5))

    cells = [
        rnn_cell_impl.BasicRNNCell,
        rnn_cell_impl.GRUCell,
        rnn_cell_impl.BasicLSTMCell,
        rnn_cell_impl.LSTMCell,
    ]
    for cell_cls in cells:
      with self.cached_session():
        with self.assertRaisesRegex(ValueError,
                                    "RNN cell only supports floating"):
          cell = cell_cls(2, dtype=dtypes.int32)
          rnn.dynamic_rnn(cell, inputs, dtype=dtypes.int32)

  @test_util.run_in_graph_and_eager_modes
  def testBatchSizeFromInput(self):
    cell = Plus1RNNCell()
    in_eager_mode = context.executing_eagerly()
    # With static batch size
    if in_eager_mode:
      inputs = np.zeros((3, 4, 5), dtype=np.float32)
      initial_state = np.zeros((3, 5), dtype=np.float32)
    else:
      inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5))
      initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5))

    # - Without initial_state
    outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
    self.assertEqual(3, outputs.shape[0])
    self.assertEqual(3, state.shape[0])

    # - With initial_state
    outputs, state = rnn.dynamic_rnn(
        cell, inputs, initial_state=initial_state)
    self.assertEqual(3, outputs.shape[0])
    self.assertEqual(3, state.shape[0])

    # Without static batch size
    # Tensor shapes are fully determined with eager execution enabled,
    # so only run this test for graph construction.
    if not in_eager_mode:
      inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 5))
      # - Without initial_state
      outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
      self.assertEqual(None, outputs.shape.dims[0].value)
      self.assertEqual(None, state.shape.dims[0].value)
      # - With initial_state
      outputs, state = rnn.dynamic_rnn(
          cell,
          inputs,
          initial_state=array_ops.placeholder(dtypes.float32, shape=(None, 5)))
      self.assertEqual(None, outputs.shape.dims[0].value)
      self.assertEqual(None, state.shape.dims[0].value)

  @test_util.run_in_graph_and_eager_modes
  def testScalarStateIsAccepted(self):
    cell = ScalarStateRNNCell()
    in_eager_mode = context.executing_eagerly()

    if in_eager_mode:
      inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
    else:
      inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))

    with self.cached_session() as sess:
      outputs, state = rnn.dynamic_rnn(
          cell, inputs, dtype=dtypes.float32, sequence_length=[4])
      if not in_eager_mode:
        outputs, state = sess.run(
            [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]})

    self.assertAllEqual([[[1], [2], [3], [4]]], outputs)
    self.assertAllEqual(4, state)

  @test_util.run_in_graph_and_eager_modes
  def testUnbalancedOutputIsAccepted(self):
    cell = UnbalancedOutputRNNCell()
    in_eager_mode = context.executing_eagerly()

    if in_eager_mode:
      inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
    else:
      inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))

    with self.cached_session() as sess:
      outputs, state = rnn.dynamic_rnn(
          cell, inputs, dtype=dtypes.float32, sequence_length=[4])
      if not in_eager_mode:
        outputs, state = sess.run(
            [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]})

    self.assertIsInstance(outputs, tuple)
    self.assertAllEqual([[[1], [2], [3], [4]]], outputs[0])
    self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1])
    self.assertAllEqual(4, state)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testEagerMemory(self):
    with context.eager_mode():
      cell = TensorArrayStateRNNCell()
      inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
      rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=[4])

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testTensorArrayStateIsAccepted(self):
    cell = TensorArrayStateRNNCell()
    in_eager_mode = context.executing_eagerly()

    if in_eager_mode:
      inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
    else:
      inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))

    with self.cached_session() as sess:
      outputs, state = rnn.dynamic_rnn(
          cell, inputs, dtype=dtypes.float32, sequence_length=[4])
      state = (state[0], state[1].stack())
      if not in_eager_mode:
        outputs, state = sess.run(
            [outputs, state], feed_dict={
                inputs: [[[1], [2], [3], [4]]]
            })

    self.assertAllEqual([[[1], [2], [3], [4]]], outputs)
    self.assertAllEqual(4, state[0])
    self.assertAllEqual([[[1]], [[2]], [[3]], [[4]]], state[1])

  @test_util.run_deprecated_v1
  def testCellGetInitialState(self):
    cell = rnn_cell_impl.BasicRNNCell(5)
    with self.assertRaisesRegex(ValueError,
                                "batch_size and dtype cannot be None"):
      cell.get_initial_state(None, None, None)

    inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 1))
    with self.assertRaisesRegex(
        ValueError, "batch size from input tensor is different from"):
      cell.get_initial_state(inputs=inputs, batch_size=50, dtype=None)

    with self.assertRaisesRegex(
        ValueError, "batch size from input tensor is different from"):
      cell.get_initial_state(
          inputs=inputs, batch_size=constant_op.constant(50), dtype=None)

    with self.assertRaisesRegex(ValueError,
                                "dtype from input tensor is different from"):
      cell.get_initial_state(inputs=inputs, batch_size=None, dtype=dtypes.int16)

    initial_state = cell.get_initial_state(
        inputs=inputs, batch_size=None, dtype=None)
    self.assertEqual(initial_state.shape.as_list(), [None, 5])
    self.assertEqual(initial_state.dtype, inputs.dtype)

    batch = array_ops.shape(inputs)[0]
    dtype = inputs.dtype
    initial_state = cell.get_initial_state(None, batch, dtype)
    self.assertEqual(initial_state.shape.as_list(), [None, 5])
    self.assertEqual(initial_state.dtype, inputs.dtype)

  def _assert_cell_builds(self, cell_class, dtype, batch_size, in_size,
                          out_size):
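    """Builds the cell and checks one step produces the expected shape."""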
    cell = cell_class(out_size, dtype=dtype)
    in_shape = tensor_shape.TensorShape((batch_size, in_size))
    cell.build(in_shape)
    state_output = cell.get_initial_state(
        inputs=None, batch_size=batch_size, dtype=dtype)
    cell_output, _ = cell(array_ops.zeros(in_shape, dtype), state_output)
    self.assertAllEqual([batch_size, out_size], cell_output.shape.as_list())

  @test_util.run_in_graph_and_eager_modes
  def testCellsBuild(self):
    f32 = dtypes.float32
    f64 = dtypes.float64
    self._assert_cell_builds(rnn_cell_impl.BasicRNNCell, f32, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.BasicRNNCell, f64, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.BasicLSTMCell, f32, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.BasicLSTMCell, f64, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.GRUCell, f32, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.GRUCell, f64, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.LSTMCell, f32, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.LSTMCell, f64, 5, 7, 3)

  @test_util.run_deprecated_v1
  def testBasicLSTMCellInterchangeWithLSTMCell(self):
    with self.session(graph=ops_lib.Graph()) as sess:
      basic_cell = rnn_cell_impl.BasicLSTMCell(1)
      basic_cell(array_ops.ones([1, 1]),
                 state=basic_cell.get_initial_state(inputs=None,
                                                    batch_size=1,
                                                    dtype=dtypes.float32))
      self.evaluate([v.initializer for v in basic_cell.variables])
      self.evaluate(basic_cell._bias.assign([10.] * 4))
      save = saver.Saver()
      prefix = os.path.join(self.get_temp_dir(), "ckpt")
      save_path = save.save(sess, prefix)

    with self.session(graph=ops_lib.Graph()) as sess:
      lstm_cell = rnn_cell_impl.LSTMCell(1, name="basic_lstm_cell")
      lstm_cell(array_ops.ones([1, 1]),
                state=lstm_cell.get_initial_state(inputs=None,
                                                  batch_size=1,
                                                  dtype=dtypes.float32))
      self.evaluate([v.initializer for v in lstm_cell.variables])
      save = saver.Saver()
      save.restore(sess, save_path)
      self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias))

######### Benchmarking RNN code
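#
# Each "_benchmark" helper below builds an LSTM graph (static_rnn or
# dynamic_rnn), takes gradients with respect to every trainable variable, and
# returns a single grouped op covering the forward and backward pass.  The
# top-level benchmark functions time either graph construction (via timeit)
# or repeated execution of that grouped op (via _timer), and BenchmarkRNN at
# the bottom of the file reports the results through report_benchmark().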


def _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length):
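  """Builds a static_rnn LSTM and its gradients; returns one grouped op."""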
  (_, input_size) = inputs_list_t[0].get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=False)
  outputs, final_state = rnn.static_rnn(
      cell,
      inputs_list_t,
      sequence_length=sequence_length,
      dtype=dtypes.float32)

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients(outputs + [final_state],
                                       trainable_variables)

  return control_flow_ops.group(final_state, *(gradients + outputs))


def _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length):
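  """Builds a dynamic_rnn LSTM and its gradients; returns one grouped op."""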
  (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=False)
  outputs, final_state = rnn.dynamic_rnn(
      cell, inputs_t, sequence_length=sequence_length, dtype=dtypes.float32)

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients([outputs, final_state],
                                       trainable_variables)

  return control_flow_ops.group(final_state, outputs, *gradients)


def graph_creation_static_vs_dynamic_rnn_benchmark(max_time):
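  """Times graph construction of the static and dynamic RNN over 5 runs."""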
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # The actual values don't matter; only graph-construction time is measured.
  batch_size = 512
  num_units = 512

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = np.random.randint(0, max_time, size=batch_size)
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]
  inputs = np.dstack(inputs_list).transpose([0, 2, 1])  # batch x time x depth

  def _create_static_rnn():
    with session.Session(config=config, graph=ops_lib.Graph()):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length)

  def _create_dynamic_rnn():
    with session.Session(config=config, graph=ops_lib.Graph()):
      inputs_t = variables_lib.Variable(inputs, trainable=False).value()
      _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length)

  delta_static = timeit.timeit(_create_static_rnn, number=5)
  delta_dynamic = timeit.timeit(_create_dynamic_rnn, number=5)

  print("%d \t %f \t %f \t %f" %
        (max_time, delta_static, delta_dynamic, delta_dynamic / delta_static))
  return delta_static, delta_dynamic


def _timer(sess, ops):
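  """Returns the average wall time of sess.run(ops) over 20 timed runs."""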
  # Warm-up runs
  for _ in range(2):
    sess.run(ops)

  # Timing run
  runs = 20
  start = time.time()
  for _ in range(runs):
    sess.run(ops)
  end = time.time()
  return (end - start) / float(runs)


def static_vs_dynamic_rnn_benchmark(batch_size, max_time, num_units, use_gpu):
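  """Compares execution time of static_rnn vs. dynamic_rnn LSTMs."""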
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = np.random.randint(0, max_time, size=batch_size)
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]
  inputs = np.dstack(inputs_list).transpose([0, 2, 1])  # batch x time x depth

  # Using static_rnn()
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _static_vs_dynamic_rnn_benchmark_static(inputs_list_t,
                                                    sequence_length)
    variables_lib.global_variables_initializer().run()
    delta_static = _timer(sess, ops)

  # Using dynamic_rnn()
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_t = variables_lib.Variable(inputs, trainable=False).value()
      ops = _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length)
    variables_lib.global_variables_initializer().run()
    delta_dynamic = _timer(sess, ops)

  print("%d \t %d \t %d \t %s \t %f \t %f \t %f" %
        (batch_size, max_time, num_units, use_gpu, delta_static, delta_dynamic,
         delta_dynamic / delta_static))

  return delta_static, delta_dynamic


def _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t, sequence_length):
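  """Builds a static_rnn LSTM and its gradients; returns one grouped op."""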
  (_, input_size) = inputs_list_t[0].get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=False)
  outputs, final_state = rnn.static_rnn(
      cell,
      inputs_list_t,
      sequence_length=sequence_length,
      dtype=dtypes.float32)

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients(outputs + [final_state],
                                       trainable_variables)

  return control_flow_ops.group(final_state, *(gradients + outputs))


def half_seq_len_vs_unroll_half_rnn_benchmark(batch_size, max_time, num_units,
                                              use_gpu):
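  """Compares halving sequence_length against halving the static unroll."""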
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = max_time * np.ones((batch_size,))
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]

  # Halve the sequence length, full static unroll
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t,
                                                       sequence_length / 2)
    variables_lib.global_variables_initializer().run()
    delta_half_seq_len = _timer(sess, ops)

  # Halve the unroll size; sequence_length / 2 now covers the whole unroll
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _half_seq_len_vs_unroll_half_rnn_benchmark(
          inputs_list_t[:(max_time // 2)], sequence_length / 2)
    variables_lib.global_variables_initializer().run()
    delta_unroll_half = _timer(sess, ops)
  print("%d \t %d \t\t %d \t %s \t %f \t\t %f \t\t %f" %
        (batch_size, max_time, num_units, use_gpu, delta_half_seq_len,
         delta_unroll_half, delta_half_seq_len / delta_unroll_half))

  return delta_half_seq_len, delta_unroll_half


def _concat_state_vs_tuple_state_rnn_benchmark(inputs_list_t, sequence_length,
                                               state_is_tuple):
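  """Builds a static_rnn LSTM using concatenated or tuple state."""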
  (_, input_size) = inputs_list_t[0].get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=state_is_tuple)
  outputs, final_state = rnn.static_rnn(
      cell,
      inputs_list_t,
      sequence_length=sequence_length,
      dtype=dtypes.float32)

  final_state = list(final_state) if state_is_tuple else [final_state]

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients(outputs + final_state,
                                       trainable_variables)

  return control_flow_ops.group(*(final_state + gradients + outputs))


def concat_state_vs_tuple_state_rnn_benchmark(batch_size, max_time, num_units,
                                              use_gpu):
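  """Compares LSTM execution with concatenated state vs. tuple state."""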
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = max_time * np.ones((batch_size,))
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]

  # Run with concatenated states (default)
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _concat_state_vs_tuple_state_rnn_benchmark(
          inputs_list_t, sequence_length, state_is_tuple=False)
    variables_lib.global_variables_initializer().run()
    delta_concat_state = _timer(sess, ops)

  # Run with tuple states (new)
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _concat_state_vs_tuple_state_rnn_benchmark(
          inputs_list_t, sequence_length, state_is_tuple=True)
    variables_lib.global_variables_initializer().run()
    delta_tuple_state = _timer(sess, ops)
  print("%d \t %d \t %d \t %s \t %f \t\t %f \t\t %f" %
        (batch_size, max_time, num_units, use_gpu, delta_concat_state,
         delta_tuple_state, delta_concat_state / delta_tuple_state))

  return delta_concat_state, delta_tuple_state


def _dynamic_rnn_swap_memory_benchmark(inputs_t, sequence_length, swap_memory):
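  """Builds a dynamic_rnn LSTM with optional memory swapping."""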
  (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=False)
  outputs, final_state = rnn.dynamic_rnn(
      cell,
      inputs_t,
      sequence_length=sequence_length,
      swap_memory=swap_memory,
      dtype=dtypes.float32)

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients([outputs, final_state],
                                       trainable_variables)

  return control_flow_ops.group(final_state, outputs, *gradients)


def dynamic_rnn_swap_memory_benchmark(batch_size, max_time, num_units):
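  """Compares dynamic_rnn execution with and without memory swapping."""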
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = np.random.randint(0, max_time, size=batch_size)
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]
  inputs = np.dstack(inputs_list).transpose([0, 2, 1])  # batch x time x depth

  # No memory swap
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    inputs_t = variables_lib.Variable(inputs, trainable=False).value()
    ops = _dynamic_rnn_swap_memory_benchmark(
        inputs_t, sequence_length, swap_memory=False)
    variables_lib.global_variables_initializer().run()
    no_swap = _timer(sess, ops)

  # Memory swap
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    inputs_t = variables_lib.Variable(inputs, trainable=False).value()
    ops = _dynamic_rnn_swap_memory_benchmark(
        inputs_t, sequence_length, swap_memory=True)
    variables_lib.global_variables_initializer().run()
    swap = _timer(sess, ops)

  print("%d \t %d \t %d \t %f \t %f \t %f" %
        (batch_size, max_time, num_units, no_swap, swap, swap / no_swap))
  return no_swap, swap


def rnn_long_sequence_benchmark(batch_size, seqlen, num_units, dynamic,
                                swap_memory, nn):
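  """Times a static or dynamic LSTM over a long sequence, nn times."""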
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = [seqlen for _ in range(batch_size)]
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(seqlen)
  ]
  inputs = np.dstack(inputs_list).transpose([0, 2, 1])  # batch x time x depth

  for _ in range(nn):
    if dynamic:
      with session.Session(config=config, graph=ops_lib.Graph()) as sess:
        inputs_t = variables_lib.Variable(inputs, trainable=False).value()
        ops = _dynamic_rnn_swap_memory_benchmark(
            inputs_t, sequence_length, swap_memory=swap_memory)
        variables_lib.global_variables_initializer().run()
        elapsed = _timer(sess, ops)
    else:
      with session.Session(config=config, graph=ops_lib.Graph()) as sess:
        inputs_list_t = [
            variables_lib.Variable(
                x, trainable=False).value() for x in inputs_list
        ]
        ops = _static_vs_dynamic_rnn_benchmark_static(inputs_list_t,
                                                      sequence_length)
        variables_lib.global_variables_initializer().run()
        elapsed = _timer(sess, ops)

    print("%d \t %d \t %d \t %s \t %f \t %f" % (batch_size, seqlen, num_units,
                                                dynamic, elapsed,
                                                elapsed / seqlen))


class BenchmarkRNN(test.Benchmark):

  def benchmarkGraphCreationStaticVsDynamicLSTM(self):
    print("Graph Creation: Static Unroll vs. Dynamic Unroll LSTM")
    print("max_t \t dt(static) \t dt(dynamic) \t dt(dynamic)/dt(static)")
    for max_time in (1, 25, 50):
      s_dt, d_dt = graph_creation_static_vs_dynamic_rnn_benchmark(max_time)
      self.report_benchmark(
          name="graph_creation_time_static_T%02d" % max_time,
          iters=5,
          wall_time=s_dt)
      self.report_benchmark(
          name="graph_creation_time_dynamic_T%02d" % max_time,
          iters=5,
          wall_time=d_dt)

  def benchmarkStaticUnrollVsDynamicFlowLSTM(self):
    print("Calculation: Static Unroll with Dynamic Flow LSTM "
          "vs. Dynamic Unroll LSTM")
    print("batch \t max_t \t units \t gpu \t dt(static) \t dt(dynamic) "
          "\t dt(dynamic)/dt(static)")
    for batch_size in (256,):
      for max_time in (50,):
        for num_units in (512, 256, 128):
          for use_gpu in (False, True):
            s_dt, d_dt = static_vs_dynamic_rnn_benchmark(batch_size, max_time,
                                                         num_units, use_gpu)
            self.report_benchmark(
                name="static_unroll_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=s_dt)
            self.report_benchmark(
                name="dynamic_unroll_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=d_dt)

  def benchmarkDynamicLSTMNoMemorySwapVsMemorySwap(self):
    print("Calculation: Dynamic LSTM No Memory Swap vs. Memory Swap")
    print("batch \t max_t \t units \t no_swap \t swap \t swap/no_swap")
    for batch_size in (256, 512):
      for max_time in (100,):
        for num_units in (512, 256, 128):
          no_swap, swap = dynamic_rnn_swap_memory_benchmark(batch_size,
                                                            max_time, num_units)
          self.report_benchmark(
              name="dynamic_lstm_no_memory_swap_T%02d_B%03d_N%03d" %
              (max_time, batch_size, num_units),
              iters=20,
              wall_time=no_swap)
          self.report_benchmark(
              name="dynamic_lstm_with_memory_swap_T%02d_B%03d_N%03d" %
              (max_time, batch_size, num_units),
              iters=20,
              wall_time=swap)

  def benchmarkStaticUnrollHalfSequenceLengthVsHalfUnroll(self):
    print("Calculation: Static Unroll with Halved Sequence Length "
          "vs. Half Static Unroll")
    print("batch \t full_t \t units \t gpu \t dt(half_seq_len) "
          "\t dt(unroll_half) \t dt(half_seq_len)/dt(unroll_half)")
    for batch_size in (128,):
      for max_time in (50,):
        for num_units in (256,):
          for use_gpu in (False, True):
            s_dt, d_dt = half_seq_len_vs_unroll_half_rnn_benchmark(batch_size,
                                                                   max_time,
                                                                   num_units,
                                                                   use_gpu)
            self.report_benchmark(
                name="half_seq_len_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=s_dt)
            self.report_benchmark(
                name="unroll_half_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=d_dt)

  def benchmarkStaticUnrollStateConcatVsStateTuple(self):
    print("Calculation: Static Unroll with Concatenated State "
          "vs. Tuple State")
    print("batch \t time \t units \t gpu \t dt(concat_state) "
          "\t dt(tuple_state) \t dt(concat_state)/dt(tuple_state)")
    for batch_size in (16, 128):
      for max_time in (50,):
        for num_units in (16, 128):
          for use_gpu in (False, True):
            c_dt, t_dt = concat_state_vs_tuple_state_rnn_benchmark(batch_size,
                                                                   max_time,
                                                                   num_units,
                                                                   use_gpu)
            self.report_benchmark(
                name="concat_state_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=c_dt)
            self.report_benchmark(
                name="tuple_state_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=t_dt)

  def _benchmarkDynamicLSTMMemorySwapLongSeq(self):
    """The memory swapping test for the SOSP submission."""
    print("Calculation: Long LSTM Sequence")
    print("batch \t len \t units \t dynamic \t elapsed_t \t elapsed_t/len")
    batch_size = 512
    seqlen = 800
    num_units = 512
    dynamic = True
    swap_memory = True
    # Some warming up.
    if swap_memory:
      rnn_long_sequence_benchmark(batch_size, seqlen, num_units,
                                  dynamic, swap_memory, 2)
    # Measure the performance.
    for slen in range(100, 1100, 100):
      rnn_long_sequence_benchmark(batch_size, slen, num_units, dynamic,
                                  swap_memory, 3)


if __name__ == "__main__":
  test.main()