1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for rnn module.""" 16 17import os 18import time 19import timeit 20 21import numpy as np 22 23from tensorflow.core.protobuf import config_pb2 24from tensorflow.python.client import session 25from tensorflow.python.eager import context 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import ops as ops_lib 29from tensorflow.python.framework import tensor_shape 30from tensorflow.python.framework import test_util 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import control_flow_ops 33from tensorflow.python.ops import gradients_impl 34from tensorflow.python.ops import init_ops 35from tensorflow.python.ops import rnn 36from tensorflow.python.ops import rnn_cell_impl 37from tensorflow.python.ops import tensor_array_ops 38from tensorflow.python.ops import variables as variables_lib 39import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import 40import tensorflow.python.ops.nn_grad # pylint: disable=unused-import 41import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import 42import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import 43from tensorflow.python.platform import test 44from tensorflow.python.training import saver 45 46 47class Plus1RNNCell(rnn_cell_impl.RNNCell): 48 """RNN Cell generating (output, new_state) = (input + 1, state + 1).""" 49 50 @property 51 def output_size(self): 52 return 5 53 54 @property 55 def state_size(self): 56 return 5 57 58 def call(self, input_, state, scope=None): 59 return (input_ + 1, state + 1) 60 61 62class ScalarStateRNNCell(rnn_cell_impl.RNNCell): 63 """RNN Cell generating (output, new_state) = (input + 1, state + 1).""" 64 65 @property 66 def output_size(self): 67 return 1 68 69 @property 70 def state_size(self): 71 return tensor_shape.TensorShape([]) 72 73 def zero_state(self, batch_size, dtype): 74 return array_ops.zeros([], dtype=dtypes.int32) 75 76 def call(self, input_, state, scope=None): 77 return (input_, state + 1) 78 79 80class UnbalancedOutputRNNCell(rnn_cell_impl.RNNCell): 81 """RNN Cell generating (output, new_state) = (input + 1, state + 1).""" 82 83 @property 84 def output_size(self): 85 return tensor_shape.TensorShape(1), tensor_shape.TensorShape((2)) 86 87 @property 88 def state_size(self): 89 return tensor_shape.TensorShape([]) 90 91 def zero_state(self, batch_size, dtype): 92 return array_ops.zeros([], dtype=dtypes.int32) 93 94 def call(self, input_, state, scope=None): 95 concatenated = array_ops.concat((input_, input_), axis=-1) 96 return (input_, concatenated), state + 1 97 98 99class TensorArrayStateRNNCell(rnn_cell_impl.RNNCell): 100 """RNN Cell its state as a TensorArray.""" 101 102 @property 103 def output_size(self): 104 return 1 105 106 @property 107 def state_size(self): 108 return (tensor_shape.TensorShape([]), ()) 109 110 def zero_state(self, batch_size, dtype): 111 return (array_ops.zeros([], dtype=dtypes.int32), 112 tensor_array_ops.TensorArray( 113 dtype=dtype, size=0, dynamic_size=True)) 114 115 def call(self, input_, state, scope=None): 116 new_array = state[1].write(state[0], input_) 117 return (input_, (state[0] + 1, new_array)) 118 119 120class RNNTest(test.TestCase): 121 122 def setUp(self): 123 self._seed = 23489 124 np.random.seed(self._seed) 125 126 @test_util.run_in_graph_and_eager_modes 127 def testInvalidSequenceLengthShape(self): 128 cell = Plus1RNNCell() 129 if context.executing_eagerly(): 130 inputs = [constant_op.constant(np.ones((3, 4)))] 131 else: 132 inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))] 133 with self.assertRaisesRegex(ValueError, "must be a vector"): 134 rnn.dynamic_rnn( 135 cell, 136 array_ops.stack(inputs), 137 dtype=dtypes.float32, 138 sequence_length=[[4]]) 139 140 @test_util.run_in_graph_and_eager_modes 141 def testInvalidDtype(self): 142 if context.executing_eagerly(): 143 inputs = np.zeros((3, 4, 5), dtype=np.int32) 144 else: 145 inputs = array_ops.placeholder(dtypes.int32, shape=(3, 4, 5)) 146 147 cells = [ 148 rnn_cell_impl.BasicRNNCell, 149 rnn_cell_impl.GRUCell, 150 rnn_cell_impl.BasicLSTMCell, 151 rnn_cell_impl.LSTMCell, 152 ] 153 for cell_cls in cells: 154 with self.cached_session(): 155 with self.assertRaisesRegex(ValueError, 156 "RNN cell only supports floating"): 157 cell = cell_cls(2, dtype=dtypes.int32) 158 rnn.dynamic_rnn(cell, inputs, dtype=dtypes.int32) 159 160 @test_util.run_in_graph_and_eager_modes 161 def testBatchSizeFromInput(self): 162 cell = Plus1RNNCell() 163 in_eager_mode = context.executing_eagerly() 164 # With static batch size 165 if in_eager_mode: 166 inputs = np.zeros((3, 4, 5), dtype=np.float32) 167 initial_state = np.zeros((3, 5), dtype=np.float32) 168 else: 169 inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5)) 170 initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5)) 171 172 # - Without initial_state 173 outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) 174 self.assertEqual(3, outputs.shape[0]) 175 self.assertEqual(3, state.shape[0]) 176 177 # - With initial_state 178 outputs, state = rnn.dynamic_rnn( 179 cell, inputs, initial_state=initial_state) 180 self.assertEqual(3, outputs.shape[0]) 181 self.assertEqual(3, state.shape[0]) 182 183 # Without static batch size 184 # Tensor shapes are fully determined with eager execution enabled, 185 # so only run this test for graph construction. 186 if not in_eager_mode: 187 inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 5)) 188 # - Without initial_state 189 outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) 190 self.assertEqual(None, outputs.shape.dims[0].value) 191 self.assertEqual(None, state.shape.dims[0].value) 192 # - With initial_state 193 outputs, state = rnn.dynamic_rnn( 194 cell, 195 inputs, 196 initial_state=array_ops.placeholder(dtypes.float32, shape=(None, 5))) 197 self.assertEqual(None, outputs.shape.dims[0].value) 198 self.assertEqual(None, state.shape.dims[0].value) 199 200 @test_util.run_in_graph_and_eager_modes 201 def testScalarStateIsAccepted(self): 202 cell = ScalarStateRNNCell() 203 in_eager_mode = context.executing_eagerly() 204 205 if in_eager_mode: 206 inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) 207 else: 208 inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) 209 210 with self.cached_session() as sess: 211 outputs, state = rnn.dynamic_rnn( 212 cell, inputs, dtype=dtypes.float32, sequence_length=[4]) 213 if not in_eager_mode: 214 outputs, state = sess.run( 215 [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]}) 216 217 self.assertAllEqual([[[1], [2], [3], [4]]], outputs) 218 self.assertAllEqual(4, state) 219 220 @test_util.run_in_graph_and_eager_modes 221 def testUnbalancedOutputIsAccepted(self): 222 cell = UnbalancedOutputRNNCell() 223 in_eager_mode = context.executing_eagerly() 224 225 if in_eager_mode: 226 inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) 227 else: 228 inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) 229 230 with self.cached_session() as sess: 231 outputs, state = rnn.dynamic_rnn( 232 cell, inputs, dtype=dtypes.float32, sequence_length=[4]) 233 if not in_eager_mode: 234 outputs, state = sess.run( 235 [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]}) 236 237 self.assertIsInstance(outputs, tuple) 238 self.assertAllEqual([[[1], [2], [3], [4]]], outputs[0]) 239 self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1]) 240 self.assertAllEqual(4, state) 241 242 @test_util.assert_no_new_pyobjects_executing_eagerly 243 def testEagerMemory(self): 244 with context.eager_mode(): 245 cell = TensorArrayStateRNNCell() 246 inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) 247 rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=[4]) 248 249 @test_util.run_in_graph_and_eager_modes 250 @test_util.run_v1_only("b/120545219") 251 def testTensorArrayStateIsAccepted(self): 252 cell = TensorArrayStateRNNCell() 253 in_eager_mode = context.executing_eagerly() 254 255 if in_eager_mode: 256 inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) 257 else: 258 inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) 259 260 with self.cached_session() as sess: 261 outputs, state = rnn.dynamic_rnn( 262 cell, inputs, dtype=dtypes.float32, sequence_length=[4]) 263 state = (state[0], state[1].stack()) 264 if not in_eager_mode: 265 outputs, state = sess.run( 266 [outputs, state], feed_dict={ 267 inputs: [[[1], [2], [3], [4]]] 268 }) 269 270 self.assertAllEqual([[[1], [2], [3], [4]]], outputs) 271 self.assertAllEqual(4, state[0]) 272 self.assertAllEqual([[[1]], [[2]], [[3]], [[4]]], state[1]) 273 274 @test_util.run_deprecated_v1 275 def testCellGetInitialState(self): 276 cell = rnn_cell_impl.BasicRNNCell(5) 277 with self.assertRaisesRegex(ValueError, 278 "batch_size and dtype cannot be None"): 279 cell.get_initial_state(None, None, None) 280 281 inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 1)) 282 with self.assertRaisesRegex( 283 ValueError, "batch size from input tensor is different from"): 284 cell.get_initial_state(inputs=inputs, batch_size=50, dtype=None) 285 286 with self.assertRaisesRegex( 287 ValueError, "batch size from input tensor is different from"): 288 cell.get_initial_state( 289 inputs=inputs, batch_size=constant_op.constant(50), dtype=None) 290 291 with self.assertRaisesRegex(ValueError, 292 "dtype from input tensor is different from"): 293 cell.get_initial_state(inputs=inputs, batch_size=None, dtype=dtypes.int16) 294 295 initial_state = cell.get_initial_state( 296 inputs=inputs, batch_size=None, dtype=None) 297 self.assertEqual(initial_state.shape.as_list(), [None, 5]) 298 self.assertEqual(initial_state.dtype, inputs.dtype) 299 300 batch = array_ops.shape(inputs)[0] 301 dtype = inputs.dtype 302 initial_state = cell.get_initial_state(None, batch, dtype) 303 self.assertEqual(initial_state.shape.as_list(), [None, 5]) 304 self.assertEqual(initial_state.dtype, inputs.dtype) 305 306 def _assert_cell_builds(self, cell_class, dtype, batch_size, in_size, 307 out_size): 308 cell = cell_class(out_size, dtype=dtype) 309 in_shape = tensor_shape.TensorShape((batch_size, in_size)) 310 cell.build(in_shape) 311 state_output = cell.get_initial_state( 312 inputs=None, batch_size=batch_size, dtype=dtype) 313 cell_output, _ = cell(array_ops.zeros(in_shape, dtype), state_output) 314 self.assertAllEqual([batch_size, out_size], cell_output.shape.as_list()) 315 316 @test_util.run_in_graph_and_eager_modes 317 def testCellsBuild(self): 318 f32 = dtypes.float32 319 f64 = dtypes.float64 320 self._assert_cell_builds(rnn_cell_impl.BasicRNNCell, f32, 5, 7, 3) 321 self._assert_cell_builds(rnn_cell_impl.BasicRNNCell, f64, 5, 7, 3) 322 self._assert_cell_builds(rnn_cell_impl.BasicLSTMCell, f32, 5, 7, 3) 323 self._assert_cell_builds(rnn_cell_impl.BasicLSTMCell, f64, 5, 7, 3) 324 self._assert_cell_builds(rnn_cell_impl.GRUCell, f32, 5, 7, 3) 325 self._assert_cell_builds(rnn_cell_impl.GRUCell, f64, 5, 7, 3) 326 self._assert_cell_builds(rnn_cell_impl.LSTMCell, f32, 5, 7, 3) 327 self._assert_cell_builds(rnn_cell_impl.LSTMCell, f64, 5, 7, 3) 328 329 @test_util.run_deprecated_v1 330 def testBasicLSTMCellInterchangeWithLSTMCell(self): 331 with self.session(graph=ops_lib.Graph()) as sess: 332 basic_cell = rnn_cell_impl.BasicLSTMCell(1) 333 basic_cell(array_ops.ones([1, 1]), 334 state=basic_cell.get_initial_state(inputs=None, 335 batch_size=1, 336 dtype=dtypes.float32)) 337 self.evaluate([v.initializer for v in basic_cell.variables]) 338 self.evaluate(basic_cell._bias.assign([10.] * 4)) 339 save = saver.Saver() 340 prefix = os.path.join(self.get_temp_dir(), "ckpt") 341 save_path = save.save(sess, prefix) 342 343 with self.session(graph=ops_lib.Graph()) as sess: 344 lstm_cell = rnn_cell_impl.LSTMCell(1, name="basic_lstm_cell") 345 lstm_cell(array_ops.ones([1, 1]), 346 state=lstm_cell.get_initial_state(inputs=None, 347 batch_size=1, 348 dtype=dtypes.float32)) 349 self.evaluate([v.initializer for v in lstm_cell.variables]) 350 save = saver.Saver() 351 save.restore(sess, save_path) 352 self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias)) 353 354######### Benchmarking RNN code 355 356 357def _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length): 358 (_, input_size) = inputs_list_t[0].get_shape().as_list() 359 initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127) 360 cell = rnn_cell_impl.LSTMCell( 361 num_units=input_size, 362 use_peepholes=True, 363 initializer=initializer, 364 state_is_tuple=False) 365 outputs, final_state = rnn.static_rnn( 366 cell, 367 inputs_list_t, 368 sequence_length=sequence_length, 369 dtype=dtypes.float32) 370 371 trainable_variables = ops_lib.get_collection( 372 ops_lib.GraphKeys.TRAINABLE_VARIABLES) 373 gradients = gradients_impl.gradients(outputs + [final_state], 374 trainable_variables) 375 376 return control_flow_ops.group(final_state, *(gradients + outputs)) 377 378 379def _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length): 380 (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list() 381 initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127) 382 cell = rnn_cell_impl.LSTMCell( 383 num_units=input_size, 384 use_peepholes=True, 385 initializer=initializer, 386 state_is_tuple=False) 387 outputs, final_state = rnn.dynamic_rnn( 388 cell, inputs_t, sequence_length=sequence_length, dtype=dtypes.float32) 389 390 trainable_variables = ops_lib.get_collection( 391 ops_lib.GraphKeys.TRAINABLE_VARIABLES) 392 gradients = gradients_impl.gradients([outputs, final_state], 393 trainable_variables) 394 395 return control_flow_ops.group(final_state, outputs, *gradients) 396 397 398def graph_creation_static_vs_dynamic_rnn_benchmark(max_time): 399 config = config_pb2.ConfigProto() 400 config.allow_soft_placement = True 401 402 # These parameters don't matter 403 batch_size = 512 404 num_units = 512 405 406 # Set up sequence lengths 407 np.random.seed([127]) 408 sequence_length = np.random.randint(0, max_time, size=batch_size) 409 inputs_list = [ 410 np.random.randn(batch_size, num_units).astype(np.float32) 411 for _ in range(max_time) 412 ] 413 inputs = np.dstack(inputs_list).transpose([0, 2, 1]) # batch x time x depth 414 415 def _create_static_rnn(): 416 with session.Session(config=config, graph=ops_lib.Graph()): 417 inputs_list_t = [ 418 variables_lib.Variable( 419 x, trainable=False).value() for x in inputs_list 420 ] 421 _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length) 422 423 def _create_dynamic_rnn(): 424 with session.Session(config=config, graph=ops_lib.Graph()): 425 inputs_t = variables_lib.Variable(inputs, trainable=False).value() 426 _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length) 427 428 delta_static = timeit.timeit(_create_static_rnn, number=5) 429 delta_dynamic = timeit.timeit(_create_dynamic_rnn, number=5) 430 431 print("%d \t %f \t %f \t %f" % 432 (max_time, delta_static, delta_dynamic, delta_dynamic / delta_static)) 433 return delta_static, delta_dynamic 434 435 436def _timer(sess, ops): 437 # Warm in 438 for _ in range(2): 439 sess.run(ops) 440 441 # Timing run 442 runs = 20 443 start = time.time() 444 for _ in range(runs): 445 sess.run(ops) 446 end = time.time() 447 return (end - start) / float(runs) 448 449 450def static_vs_dynamic_rnn_benchmark(batch_size, max_time, num_units, use_gpu): 451 config = config_pb2.ConfigProto() 452 config.allow_soft_placement = True 453 454 # Set up sequence lengths 455 np.random.seed([127]) 456 sequence_length = np.random.randint(0, max_time, size=batch_size) 457 inputs_list = [ 458 np.random.randn(batch_size, num_units).astype(np.float32) 459 for _ in range(max_time) 460 ] 461 inputs = np.dstack(inputs_list).transpose([0, 2, 1]) # batch x time x depth 462 463 # Using rnn() 464 with session.Session(config=config, graph=ops_lib.Graph()) as sess: 465 with ops_lib.device("/cpu:0" if not use_gpu else None): 466 inputs_list_t = [ 467 variables_lib.Variable( 468 x, trainable=False).value() for x in inputs_list 469 ] 470 ops = _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, 471 sequence_length) 472 variables_lib.global_variables_initializer().run() 473 delta_static = _timer(sess, ops) 474 475 # Using dynamic_rnn() 476 with session.Session(config=config, graph=ops_lib.Graph()) as sess: 477 with ops_lib.device("/cpu:0" if not use_gpu else None): 478 inputs_t = variables_lib.Variable(inputs, trainable=False).value() 479 ops = _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length) 480 variables_lib.global_variables_initializer().run() 481 delta_dynamic = _timer(sess, ops) 482 483 print("%d \t %d \t %d \t %s \t %f \t %f \t %f" % 484 (batch_size, max_time, num_units, use_gpu, delta_static, delta_dynamic, 485 delta_dynamic / delta_static)) 486 487 return delta_static, delta_dynamic 488 489 490def _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t, sequence_length): 491 (_, input_size) = inputs_list_t[0].get_shape().as_list() 492 initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127) 493 cell = rnn_cell_impl.LSTMCell( 494 num_units=input_size, 495 use_peepholes=True, 496 initializer=initializer, 497 state_is_tuple=False) 498 outputs, final_state = rnn.static_rnn( 499 cell, 500 inputs_list_t, 501 sequence_length=sequence_length, 502 dtype=dtypes.float32) 503 504 trainable_variables = ops_lib.get_collection( 505 ops_lib.GraphKeys.TRAINABLE_VARIABLES) 506 gradients = gradients_impl.gradients(outputs + [final_state], 507 trainable_variables) 508 509 return control_flow_ops.group(final_state, *(gradients + outputs)) 510 511 512def half_seq_len_vs_unroll_half_rnn_benchmark(batch_size, max_time, num_units, 513 use_gpu): 514 config = config_pb2.ConfigProto() 515 config.allow_soft_placement = True 516 517 # Set up sequence lengths 518 np.random.seed([127]) 519 sequence_length = max_time * np.ones((batch_size,)) 520 inputs_list = [ 521 np.random.randn(batch_size, num_units).astype(np.float32) 522 for _ in range(max_time) 523 ] 524 525 # Halve the sequence length, full static unroll 526 with session.Session(config=config, graph=ops_lib.Graph()) as sess: 527 with ops_lib.device("/cpu:0" if not use_gpu else None): 528 inputs_list_t = [ 529 variables_lib.Variable( 530 x, trainable=False).value() for x in inputs_list 531 ] 532 ops = _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t, 533 sequence_length / 2) 534 variables_lib.global_variables_initializer().run() 535 delta_half_seq_len = _timer(sess, ops) 536 537 # Halve the unroll size, don't use sequence length 538 with session.Session(config=config, graph=ops_lib.Graph()) as sess: 539 with ops_lib.device("/cpu:0" if not use_gpu else None): 540 inputs_list_t = [ 541 variables_lib.Variable( 542 x, trainable=False).value() for x in inputs_list 543 ] 544 ops = _half_seq_len_vs_unroll_half_rnn_benchmark( 545 inputs_list_t[:(max_time // 2)], sequence_length / 2) 546 variables_lib.global_variables_initializer().run() 547 delta_unroll_half = _timer(sess, ops) 548 print("%d \t %d \t\t %d \t %s \t %f \t\t %f \t\t %f" % 549 (batch_size, max_time, num_units, use_gpu, delta_half_seq_len, 550 delta_unroll_half, delta_half_seq_len / delta_unroll_half)) 551 552 return delta_half_seq_len, delta_unroll_half 553 554 555def _concat_state_vs_tuple_state_rnn_benchmark(inputs_list_t, sequence_length, 556 state_is_tuple): 557 (_, input_size) = inputs_list_t[0].get_shape().as_list() 558 initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127) 559 cell = rnn_cell_impl.LSTMCell( 560 num_units=input_size, 561 use_peepholes=True, 562 initializer=initializer, 563 state_is_tuple=state_is_tuple) 564 outputs, final_state = rnn.static_rnn( 565 cell, 566 inputs_list_t, 567 sequence_length=sequence_length, 568 dtype=dtypes.float32) 569 570 final_state = list(final_state) if state_is_tuple else [final_state] 571 572 trainable_variables = ops_lib.get_collection( 573 ops_lib.GraphKeys.TRAINABLE_VARIABLES) 574 gradients = gradients_impl.gradients(outputs + final_state, 575 trainable_variables) 576 577 return control_flow_ops.group(*(final_state + gradients + outputs)) 578 579 580def concat_state_vs_tuple_state_rnn_benchmark(batch_size, max_time, num_units, 581 use_gpu): 582 config = config_pb2.ConfigProto() 583 config.allow_soft_placement = True 584 585 # Set up sequence lengths 586 np.random.seed([127]) 587 sequence_length = max_time * np.ones((batch_size,)) 588 inputs_list = [ 589 np.random.randn(batch_size, num_units).astype(np.float32) 590 for _ in range(max_time) 591 ] 592 593 # Run with concatenated states (default) 594 with session.Session(config=config, graph=ops_lib.Graph()) as sess: 595 with ops_lib.device("/cpu:0" if not use_gpu else None): 596 inputs_list_t = [ 597 variables_lib.Variable( 598 x, trainable=False).value() for x in inputs_list 599 ] 600 ops = _concat_state_vs_tuple_state_rnn_benchmark( 601 inputs_list_t, sequence_length, state_is_tuple=False) 602 variables_lib.global_variables_initializer().run() 603 delta_concat_state = _timer(sess, ops) 604 605 # Run with tuple states (new) 606 with session.Session(config=config, graph=ops_lib.Graph()) as sess: 607 with ops_lib.device("/cpu:0" if not use_gpu else None): 608 inputs_list_t = [ 609 variables_lib.Variable( 610 x, trainable=False).value() for x in inputs_list 611 ] 612 ops = _concat_state_vs_tuple_state_rnn_benchmark( 613 inputs_list_t, sequence_length, state_is_tuple=True) 614 variables_lib.global_variables_initializer().run() 615 delta_tuple_state = _timer(sess, ops) 616 print("%d \t %d \t %d \t %s \t %f \t\t %f \t\t %f" % 617 (batch_size, max_time, num_units, use_gpu, delta_concat_state, 618 delta_tuple_state, delta_concat_state / delta_tuple_state)) 619 620 return delta_concat_state, delta_tuple_state 621 622 623def _dynamic_rnn_swap_memory_benchmark(inputs_t, sequence_length, swap_memory): 624 (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list() 625 initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127) 626 cell = rnn_cell_impl.LSTMCell( 627 num_units=input_size, 628 use_peepholes=True, 629 initializer=initializer, 630 state_is_tuple=False) 631 outputs, final_state = rnn.dynamic_rnn( 632 cell, 633 inputs_t, 634 sequence_length=sequence_length, 635 swap_memory=swap_memory, 636 dtype=dtypes.float32) 637 638 trainable_variables = ops_lib.get_collection( 639 ops_lib.GraphKeys.TRAINABLE_VARIABLES) 640 gradients = gradients_impl.gradients([outputs, final_state], 641 trainable_variables) 642 643 return control_flow_ops.group(final_state, outputs, *gradients) 644 645 646def dynamic_rnn_swap_memory_benchmark(batch_size, max_time, num_units): 647 config = config_pb2.ConfigProto() 648 config.allow_soft_placement = True 649 650 # Set up sequence lengths 651 np.random.seed([127]) 652 sequence_length = np.random.randint(0, max_time, size=batch_size) 653 inputs_list = [ 654 np.random.randn(batch_size, num_units).astype(np.float32) 655 for _ in range(max_time) 656 ] 657 inputs = np.dstack(inputs_list).transpose([0, 2, 1]) # batch x time x depth 658 659 # No memory swap 660 with session.Session(config=config, graph=ops_lib.Graph()) as sess: 661 inputs_t = variables_lib.Variable(inputs, trainable=False).value() 662 ops = _dynamic_rnn_swap_memory_benchmark( 663 inputs_t, sequence_length, swap_memory=False) 664 variables_lib.global_variables_initializer().run() 665 no_swap = _timer(sess, ops) 666 667 # Memory swap 668 with session.Session(config=config, graph=ops_lib.Graph()) as sess: 669 inputs_t = variables_lib.Variable(inputs, trainable=False).value() 670 ops = _dynamic_rnn_swap_memory_benchmark( 671 inputs_t, sequence_length, swap_memory=True) 672 variables_lib.global_variables_initializer().run() 673 swap = _timer(sess, ops) 674 675 print("%d \t %d \t %d \t %f \t %f \t %f" % 676 (batch_size, max_time, num_units, no_swap, swap, swap / no_swap)) 677 return no_swap, swap 678 679 680def rnn_long_sequence_benchmark(batch_size, seqlen, num_units, dynamic, 681 swap_memory, nn): 682 config = config_pb2.ConfigProto() 683 config.allow_soft_placement = True 684 685 # Set up sequence lengths 686 np.random.seed([127]) 687 sequence_length = [seqlen for _ in range(batch_size)] 688 inputs_list = [ 689 np.random.randn(batch_size, num_units).astype(np.float32) 690 for _ in range(seqlen) 691 ] 692 inputs = np.dstack(inputs_list).transpose([0, 2, 1]) # batch x time x depth 693 694 for _ in range(nn): 695 if dynamic: 696 with session.Session(config=config, graph=ops_lib.Graph()) as sess: 697 inputs_t = variables_lib.Variable(inputs, trainable=False).value() 698 ops = _dynamic_rnn_swap_memory_benchmark( 699 inputs_t, sequence_length, swap_memory=swap_memory) 700 variables_lib.global_variables_initializer().run() 701 elapsed = _timer(sess, ops) 702 else: 703 with session.Session(config=config, graph=ops_lib.Graph()) as sess: 704 inputs_list_t = [ 705 variables_lib.Variable( 706 x, trainable=False).value() for x in inputs_list 707 ] 708 ops = _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, 709 sequence_length) 710 variables_lib.global_variables_initializer().run() 711 elapsed = _timer(sess, ops) 712 713 print("%d \t %d \t %d \t %s \t %f \t %f" % (batch_size, seqlen, num_units, 714 dynamic, elapsed, 715 elapsed / seqlen)) 716 717 718class BenchmarkRNN(test.Benchmark): 719 720 def benchmarkGraphCreationStaticVsDynamicLSTM(self): 721 print("Graph Creation: Static Unroll vs. Dynamic Unroll LSTM") 722 print("max_t \t dt(static) \t dt(dynamic) \t dt(dynamic)/dt(static)") 723 for max_time in (1, 25, 50): 724 s_dt, d_dt = graph_creation_static_vs_dynamic_rnn_benchmark(max_time) 725 self.report_benchmark( 726 name="graph_creation_time_static_T%02d" % max_time, 727 iters=5, 728 wall_time=s_dt) 729 self.report_benchmark( 730 name="graph_creation_time_dynamic_T%02d" % max_time, 731 iters=5, 732 wall_time=d_dt) 733 734 def benchmarkStaticUnrollVsDynamicFlowLSTM(self): 735 print("Calculation: Static Unroll with Dynamic Flow LSTM " 736 "vs. Dynamic Unroll LSTM") 737 print("batch \t max_t \t units \t gpu \t dt(static) \t dt(dynamic) " 738 "\t dt(dynamic)/dt(static)") 739 for batch_size in (256,): 740 for max_time in (50,): 741 for num_units in (512, 256, 128): 742 for use_gpu in (False, True): 743 s_dt, d_dt = static_vs_dynamic_rnn_benchmark(batch_size, max_time, 744 num_units, use_gpu) 745 self.report_benchmark( 746 name="static_unroll_time_T%02d_B%03d_N%03d_gpu_%s" % 747 (max_time, batch_size, num_units, use_gpu), 748 iters=20, 749 wall_time=s_dt) 750 self.report_benchmark( 751 name="dynamic_unroll_time_T%02d_B%03d_N%03d_gpu_%s" % 752 (max_time, batch_size, num_units, use_gpu), 753 iters=20, 754 wall_time=d_dt) 755 756 def benchmarkDynamicLSTMNoMemorySwapVsMemorySwap(self): 757 print("Calculation: Dynamic LSTM No Memory Swap vs. Memory Swap") 758 print("batch \t max_t \t units \t no_swap \t swap \t swap/no_swap") 759 for batch_size in (256, 512): 760 for max_time in (100,): 761 for num_units in (512, 256, 128): 762 no_swap, swap = dynamic_rnn_swap_memory_benchmark(batch_size, 763 max_time, num_units) 764 self.report_benchmark( 765 name="dynamic_lstm_no_memory_swap_T%02d_B%03d_N%03d" % 766 (max_time, batch_size, num_units), 767 iters=20, 768 wall_time=no_swap) 769 self.report_benchmark( 770 name="dynamic_lstm_with_memory_swap_T%02d_B%03d_N%03d" % 771 (max_time, batch_size, num_units), 772 iters=20, 773 wall_time=swap) 774 775 def benchmarkStaticUnrollHalfSequenceLengthVsHalfUnroll(self): 776 print("Calculation: Static Unroll with Halved Sequence Length " 777 "vs. Half Static Unroll") 778 print("batch \t full_t \t units \t gpu \t dt(half_seq_len) " 779 "\t dt(unroll_half) \t dt(half_seq_len)/dt(unroll_half)") 780 for batch_size in (128,): 781 for max_time in (50,): 782 for num_units in (256,): 783 for use_gpu in (False, True): 784 s_dt, d_dt = half_seq_len_vs_unroll_half_rnn_benchmark(batch_size, 785 max_time, 786 num_units, 787 use_gpu) 788 self.report_benchmark( 789 name="half_seq_len_time_T%02d_B%03d_N%03d_gpu_%s" % 790 (max_time, batch_size, num_units, use_gpu), 791 iters=20, 792 wall_time=s_dt) 793 self.report_benchmark( 794 name="unroll_half_time_T%02d_B%03d_N%03d_gpu_%s" % 795 (max_time, batch_size, num_units, use_gpu), 796 iters=20, 797 wall_time=d_dt) 798 799 def benchmarkStaticUnrollStateConcatVsStateTuple(self): 800 print("Calculation: Static Unroll with Concatenated State " 801 "vs. Tuple State") 802 print("batch \t time \t units \t gpu \t dt(concat_state) " 803 "\t dt(tuple_state) \t dt(concat_state)/dt(tuple_state)") 804 for batch_size in ( 805 16, 806 128,): 807 for max_time in (50,): 808 for num_units in ( 809 16, 810 128,): 811 for use_gpu in (False, True): 812 c_dt, t_dt = concat_state_vs_tuple_state_rnn_benchmark(batch_size, 813 max_time, 814 num_units, 815 use_gpu) 816 self.report_benchmark( 817 name="concat_state_time_T%02d_B%03d_N%03d_gpu_%s" % 818 (max_time, batch_size, num_units, use_gpu), 819 iters=20, 820 wall_time=c_dt) 821 self.report_benchmark( 822 name="tuple_state_time_T%02d_B%03d_N%03d_gpu_%s" % 823 (max_time, batch_size, num_units, use_gpu), 824 iters=20, 825 wall_time=t_dt) 826 827 def _benchmarkDynamicLSTMMemorySwapLongSeq(self): 828 """The memory swapping test for the SOSP submission.""" 829 print("Calculation: Long LSTM Sequence") 830 print("batch \t len \t units \t dynamic \t elapsed_t \t elapsed_t/len") 831 batch_size = 512 832 seqlen = 800 833 num_units = 512 834 dynamic = True 835 swap_memory = True 836 # Some warming up. 837 if swap_memory: 838 rnn_long_sequence_benchmark(batch_size, seqlen, num_units, 839 dynamic, swap_memory, 2) 840 # Measure the performance. 841 for slen in range(100, 1100, 100): 842 rnn_long_sequence_benchmark(batch_size, slen, num_units, dynamic, 843 swap_memory, 3) 844 845if __name__ == "__main__": 846 test.main() 847