# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for RNN cells."""

import itertools
import os

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.checkpoint import checkpoint as trackable_utils
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_rnn_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import nest


class Plus1RNNCell(rnn_cell.RNNCell):
  """RNN Cell generating (output, new_state) = (input + 1, state + 1)."""

  @property
  def output_size(self):
    return 5

  @property
  def state_size(self):
    return 5

  def __call__(self, input_, state, scope=None):
    return (input_ + 1, state + 1)


class DummyMultiDimensionalLSTM(rnn_cell.RNNCell):
  """LSTM Cell generating (output, new_state) = (input + 1, state + 1).

  The input to this cell may have an arbitrary number of dimensions that follow
  the preceding 'Time' and 'Batch' dimensions.
  """

  def __init__(self, dims):
    """Initialize the Multi-dimensional LSTM cell.

    Args:
      dims: tuple that contains the dimensions of the output of the cell,
        without including 'Time' or 'Batch' dimensions.
    """
    if not isinstance(dims, tuple):
      raise TypeError("The dimensions passed to DummyMultiDimensionalLSTM "
                      "should be a tuple of ints.")
    self._dims = dims
    self._output_size = tensor_shape.TensorShape(self._dims)
    self._state_size = (tensor_shape.TensorShape(self._dims),
                        tensor_shape.TensorShape(self._dims))

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  def __call__(self, input_, state, scope=None):
    h, c = state
    return (input_ + 1, (h + 1, c + 1))


class NestedRNNCell(rnn_cell.RNNCell):
  """RNN Cell generating (output, new_state) = (input + 1, state + 1).

  The input, output and state of this cell are each a tuple of two tensors.
  """

  @property
  def output_size(self):
    return (5, 5)

  @property
  def state_size(self):
    return (6, 6)

  def __call__(self, input_, state, scope=None):
    h, c = state
    x, y = input_
    return ((x + 1, y + 1), (h + 1, c + 1))


class TestStateSaver(object):

  def __init__(self, batch_size, state_size):
    self._batch_size = batch_size
    self._state_size = state_size
    self.saved_state = {}

  def state(self, name):

    if isinstance(self._state_size, dict):
      state_size = self._state_size[name]
    else:
      state_size = self._state_size
    if isinstance(state_size, int):
      state_size = (state_size,)
    elif isinstance(state_size, tuple):
      pass
    else:
      raise TypeError("state_size should either be an int or a tuple")

    return array_ops.zeros((self._batch_size,) + state_size)

  def save_state(self, name, state):
    self.saved_state[name] = state
    return array_ops.identity(state)

  @property
  def batch_size(self):
    return self._batch_size

  @property
  def state_size(self):
    return self._state_size


class TestStateSaverWithCounters(TestStateSaver):
  """Class wrapper around TestStateSaver.

  A dummy class used for testing static_state_saving_rnn. It helps verify that
  save_state and state are called the same number of times whether we evaluate
  the output of the rnn cell and its state together or separately. It inherits
  from TestStateSaver and adds counters for the calls of those functions.
  """

  @test_util.run_v1_only("b/124229375")
  def __init__(self, batch_size, state_size):
    super(TestStateSaverWithCounters, self).__init__(batch_size, state_size)
    self._num_state_calls = variables_lib.VariableV1(0)
    self._num_save_state_calls = variables_lib.VariableV1(0)

  def state(self, name):
    with ops.control_dependencies(
        [state_ops.assign_add(self._num_state_calls, 1)]):
      return super(TestStateSaverWithCounters, self).state(name)

  def save_state(self, name, state):
    with ops.control_dependencies([state_ops.assign_add(
        self._num_save_state_calls, 1)]):
      return super(TestStateSaverWithCounters, self).save_state(name, state)

  @property
  def num_state_calls(self):
    return self._num_state_calls

  @property
  def num_save_state_calls(self):
    return self._num_save_state_calls


class RNNTest(test.TestCase):

  def setUp(self):
    self._seed = 23489
    np.random.seed(self._seed)

  @test_util.run_v1_only("b/124229375")
  def testInvalidSequenceLengthShape(self):
    cell = Plus1RNNCell()
    inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
    with self.assertRaisesRegex(ValueError, "must be a vector"):
      rnn.static_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=4)

  @test_util.run_v1_only("b/124229375")
  def testRNN(self):
    cell = Plus1RNNCell()
    batch_size = 2
    input_size = 5
    max_length = 8  # unrolled up to this length
    inputs = max_length * [
        array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
    ]
    outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
    self.assertEqual(len(outputs), len(inputs))
    for out, inp in zip(outputs, inputs):
      self.assertEqual(out.get_shape(), inp.get_shape())
      self.assertEqual(out.dtype, inp.dtype)

    with self.session() as sess:
      input_value = np.random.randn(batch_size, input_size)
      values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})

      # Outputs
      for v in values[:-1]:
        self.assertAllClose(v, input_value + 1.0)

      # Final state
      self.assertAllClose(values[-1],
                          max_length * np.ones(
                              (batch_size, input_size), dtype=np.float32))

  @test_util.run_v1_only("b/124229375")
  def testDropout(self):
    cell = Plus1RNNCell()
    full_dropout_cell = rnn_cell.DropoutWrapper(
        cell, input_keep_prob=1e-6, seed=0)
    self.assertIn("cell", full_dropout_cell._trackable_children())
    self.assertIs(full_dropout_cell._trackable_children()["cell"], cell)
    batch_size = 2
    input_size = 5
    max_length = 8
    inputs = max_length * [
        array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
    ]
    with variable_scope.variable_scope("share_scope"):
      outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
    with variable_scope.variable_scope("drop_scope"):
      dropped_outputs, _ = rnn.static_rnn(
          full_dropout_cell, inputs, dtype=dtypes.float32)
    self.assertEqual(len(outputs), len(inputs))
    for out, inp in zip(outputs, inputs):
      self.assertEqual(out.get_shape().as_list(), inp.get_shape().as_list())
      self.assertEqual(out.dtype, inp.dtype)

    with self.session() as sess:
      input_value = np.random.randn(batch_size, input_size)
      values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
      full_dropout_values = sess.run(
          dropped_outputs, feed_dict={
              inputs[0]: input_value
          })

      for v in values[:-1]:
        self.assertAllClose(v, input_value + 1.0)
      for d_v in full_dropout_values[:-1]:  # Add 1.0 to dropped_out (all zeros)
        self.assertAllClose(d_v, np.ones_like(input_value))

  @test_util.run_v1_only("b/124229375")
  def testDynamicCalculation(self):
    cell = Plus1RNNCell()
    sequence_length = array_ops.placeholder(dtypes.int64)
    batch_size = 2
    input_size = 5
    max_length = 8
    inputs = max_length * [
        array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
    ]
    with variable_scope.variable_scope("drop_scope"):
      dynamic_outputs, dynamic_state = rnn.static_rnn(
          cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)
    self.assertEqual(len(dynamic_outputs), len(inputs))

    with self.session() as sess:
      input_value = np.random.randn(batch_size, input_size)
      dynamic_values = sess.run(
          dynamic_outputs,
          feed_dict={
              inputs[0]: input_value,
              sequence_length: [2, 3]
          })
      dynamic_state_value = sess.run(
          [dynamic_state],
          feed_dict={
              inputs[0]: input_value,
              sequence_length: [2, 3]
          })

      # outputs are fully calculated for t = 0, 1
      for v in dynamic_values[:2]:
        self.assertAllClose(v, input_value + 1.0)

      # outputs at t = 2 are zero for entry 0, calculated for entry 1
      self.assertAllClose(dynamic_values[2],
                          np.vstack((np.zeros((input_size)),
                                     1.0 + input_value[1, :])))

      # outputs at t = 3+ are zero
      for v in dynamic_values[3:]:
        self.assertAllEqual(v, np.zeros_like(input_value))

      # the final states are:
      #  entry 0: the values from the calculation at t=1
      #  entry 1: the values from the calculation at t=2
      self.assertAllEqual(dynamic_state_value[0],
                          np.vstack((1.0 * (1 + 1) * np.ones((input_size)),
                                     1.0 * (2 + 1) * np.ones((input_size)))))

  def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
    with self.session(graph=ops.Graph()):
      if use_outer_scope:
        with variable_scope.variable_scope(prefix) as scope:
          factory(scope)
      else:
        factory(prefix)

      # Check that all the variable names start
      # with the proper scope.
      variables_lib.global_variables_initializer()
      all_vars = variables_lib.global_variables()
      prefix = prefix or "rnn"
      scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")]
      tf_logging.info("RNN with scope: %s (%s)" %
                      (prefix, "scope" if use_outer_scope else "str"))
      for v in scope_vars:
        tf_logging.info(v.name)
      self.assertEqual(len(scope_vars), len(all_vars))

  @test_util.run_v1_only("b/124229375")
  def testScope(self):

    def factory(scope):
      cell = Plus1RNNCell()
      batch_size = 2
      input_size = 5
      max_length = 8  # unrolled up to this length
      inputs = max_length * [
          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
      ]
      return rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope=scope)

    self._testScope(factory, use_outer_scope=True)
    self._testScope(factory, use_outer_scope=False)
    self._testScope(factory, prefix=None, use_outer_scope=False)


class LSTMTest(test.TestCase):

  def setUp(self):
    self._seed = 23489
    np.random.seed(self._seed)

  def testDType(self):
    # Test case for GitHub issue 16228
    # Not passing dtype in constructor results in default float32
    lstm = rnn_cell.LSTMCell(10)
    input_tensor = array_ops.ones([10, 50])
    lstm.build(input_tensor.get_shape())
    self.assertEqual(lstm._bias.dtype.base_dtype, dtypes.float32)

    # Explicitly pass dtype in constructor
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      lstm = rnn_cell.LSTMCell(10, dtype=dtype)
      input_tensor = array_ops.ones([10, 50])
      lstm.build(input_tensor.get_shape())
      self.assertEqual(lstm._bias.dtype.base_dtype, dtype)

  @test_util.run_v1_only("b/124229375")
  def testNoProjNoSharding(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    max_length = 8
    with self.session(graph=ops.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)
      cell = rnn_cell.LSTMCell(
          num_units, initializer=initializer, state_is_tuple=False)
      inputs = max_length * [
          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
      ]
      outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
      self.assertEqual(len(outputs), len(inputs))
      for out in outputs:
        self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])

      variables_lib.global_variables_initializer().run()
      input_value = np.random.randn(batch_size, input_size)
      sess.run(outputs, feed_dict={inputs[0]: input_value})

  @test_util.run_v1_only("b/124229375")
  def testCellClipping(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    max_length = 8
    with self.session(graph=ops.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)
      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          cell_clip=0.0,
          initializer=initializer,
          state_is_tuple=False)
      inputs = max_length * [
          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
      ]
      outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
      self.assertEqual(len(outputs), len(inputs))
      for out in outputs:
        self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])

      variables_lib.global_variables_initializer().run()
      input_value = np.random.randn(batch_size, input_size)
      values = sess.run(outputs, feed_dict={inputs[0]: input_value})

      for value in values:
        # if cell c is clipped to 0, tanh(c) = 0 => m==0
        self.assertAllEqual(value, np.zeros((batch_size, num_units)))

  @test_util.run_v1_only("b/124229375")
  def testNoProjNoShardingSimpleStateSaver(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    max_length = 8
    with self.session(graph=ops.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)
      state_saver = TestStateSaver(batch_size, 2 * num_units)
      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=False,
          initializer=initializer,
          state_is_tuple=False)
      inputs = max_length * [
          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
      ]
      with variable_scope.variable_scope("share_scope"):
        outputs, state = rnn.static_state_saving_rnn(
            cell, inputs, state_saver=state_saver, state_name="save_lstm")
      self.assertEqual(len(outputs), len(inputs))
      for out in outputs:
        self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])

      variables_lib.global_variables_initializer().run()
      input_value = np.random.randn(batch_size, input_size)
      (last_state_value, saved_state_value) = sess.run(
          [state, state_saver.saved_state["save_lstm"]],
          feed_dict={
              inputs[0]: input_value
          })
      self.assertAllEqual(last_state_value, saved_state_value)

  @test_util.run_v1_only("b/124229375")
  def testNoProjNoShardingTupleStateSaver(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    max_length = 8
    with self.session(graph=ops.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)
      state_saver = TestStateSaver(batch_size, num_units)
      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=False,
          initializer=initializer,
          state_is_tuple=True)
      inputs = max_length * [
          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
      ]
      with variable_scope.variable_scope("share_scope"):
        outputs, state = rnn.static_state_saving_rnn(
            cell, inputs, state_saver=state_saver, state_name=("c", "m"))
      self.assertEqual(len(outputs), len(inputs))
      for out in outputs:
        self.assertEqual(out.get_shape().as_list(), [batch_size, num_units])

      variables_lib.global_variables_initializer().run()
      input_value = np.random.randn(batch_size, input_size)
      last_and_saved_states = sess.run(
          state + (state_saver.saved_state["c"], state_saver.saved_state["m"]),
          feed_dict={
              inputs[0]: input_value
          })
      self.assertEqual(4, len(last_and_saved_states))
      self.assertAllEqual(last_and_saved_states[:2], last_and_saved_states[2:])

  @test_util.run_v1_only("b/124229375")
  def testNoProjNoShardingNestedTupleStateSaver(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    max_length = 8
    with self.session(graph=ops.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)
      state_saver = TestStateSaver(
          batch_size, {
              "c0": num_units,
              "m0": num_units,
              "c1": num_units + 1,
              "m1": num_units + 1,
              "c2": num_units + 2,
              "m2": num_units + 2,
              "c3": num_units + 3,
              "m3": num_units + 3
          })

      def _cell(i):
        return rnn_cell.LSTMCell(
            num_units + i,
            use_peepholes=False,
            initializer=initializer,
            state_is_tuple=True)

      # This creates a state tuple which has 4 sub-tuples of length 2 each.
      cell = rnn_cell.MultiRNNCell(
          [_cell(i) for i in range(4)], state_is_tuple=True)

      self.assertEqual(len(cell.state_size), 4)
      for i in range(4):
        self.assertEqual(len(cell.state_size[i]), 2)

      inputs = max_length * [
          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
      ]

      state_names = (("c0", "m0"), ("c1", "m1"), ("c2", "m2"), ("c3", "m3"))
      with variable_scope.variable_scope("share_scope"):
        outputs, state = rnn.static_state_saving_rnn(
            cell, inputs, state_saver=state_saver, state_name=state_names)
      self.assertEqual(len(outputs), len(inputs))

      # Final output comes from _cell(3) which has state size num_units + 3
      for out in outputs:
        self.assertEqual(out.get_shape().as_list(), [batch_size, num_units + 3])

      variables_lib.global_variables_initializer().run()
      input_value = np.random.randn(batch_size, input_size)
      last_states = sess.run(
          list(nest.flatten(state)), feed_dict={
              inputs[0]: input_value
          })
      saved_states = sess.run(
          list(state_saver.saved_state.values()),
          feed_dict={
              inputs[0]: input_value
          })
      self.assertEqual(8, len(last_states))
      self.assertEqual(8, len(saved_states))
      flat_state_names = nest.flatten(state_names)
      named_saved_states = dict(
          zip(state_saver.saved_state.keys(), saved_states))

      for i in range(8):
        self.assertAllEqual(last_states[i],
                            named_saved_states[flat_state_names[i]])

  @test_util.run_v1_only("b/124229375")
  def testProjNoSharding(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    num_proj = 4
    max_length = 8
    with self.session(graph=ops.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)
      inputs = max_length * [
          array_ops.placeholder(dtypes.float32, shape=(None, input_size))
      ]
      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
          initializer=initializer,
          state_is_tuple=False)
      outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
      self.assertEqual(len(outputs), len(inputs))

      variables_lib.global_variables_initializer().run()
      input_value = np.random.randn(batch_size, input_size)
      sess.run(outputs, feed_dict={inputs[0]: input_value})

  def _testStateTupleWithProjAndSequenceLength(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    num_proj = 4
    max_length = 8
    sequence_length = [4, 6]
    with self.session(graph=ops.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)
      inputs = max_length * [
          array_ops.placeholder(dtypes.float32, shape=(None, input_size))
      ]
      cell_notuple = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
          initializer=initializer,
          state_is_tuple=False)
      cell_tuple = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
          initializer=initializer,
          state_is_tuple=True)
      with variable_scope.variable_scope("root") as scope:
        outputs_notuple, state_notuple = rnn.static_rnn(
            cell_notuple,
            inputs,
            dtype=dtypes.float32,
            sequence_length=sequence_length,
            scope=scope)
        scope.reuse_variables()
        # TODO(ebrevdo): For this test, we ensure values are identical and
        # therefore the weights here are tied.  In the future, we may consider
        # making the state_is_tuple property mutable so we can avoid
        # having to do this - especially if users ever need to reuse
        # the parameters from different RNNCell instances.  Right now,
        # this seems an unrealistic use case except for testing.
        cell_tuple._scope = cell_notuple._scope  # pylint: disable=protected-access
        outputs_tuple, state_tuple = rnn.static_rnn(
            cell_tuple,
            inputs,
            dtype=dtypes.float32,
            sequence_length=sequence_length,
            scope=scope)
      self.assertEqual(len(outputs_notuple), len(inputs))
      self.assertEqual(len(outputs_tuple), len(inputs))
      self.assertTrue(isinstance(state_tuple, tuple))
      self.assertTrue(isinstance(state_notuple, ops.Tensor))

      variables_lib.global_variables_initializer().run()
      input_value = np.random.randn(batch_size, input_size)
      outputs_notuple_v = sess.run(
          outputs_notuple, feed_dict={
              inputs[0]: input_value
          })
      outputs_tuple_v = sess.run(
          outputs_tuple, feed_dict={
              inputs[0]: input_value
          })
      self.assertAllEqual(outputs_notuple_v, outputs_tuple_v)

      (state_notuple_v,) = sess.run(
          (state_notuple,), feed_dict={
              inputs[0]: input_value
          })
      state_tuple_v = sess.run(state_tuple, feed_dict={inputs[0]: input_value})
      self.assertAllEqual(state_notuple_v, np.hstack(state_tuple_v))

  @test_util.run_v1_only("b/124229375")
  def testProjSharding(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    num_proj = 4
    num_proj_shards = 3
    num_unit_shards = 2
    max_length = 8
    with self.session(graph=ops.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)

      inputs = max_length * [
          array_ops.placeholder(dtypes.float32, shape=(None, input_size))
      ]

      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
          num_unit_shards=num_unit_shards,
          num_proj_shards=num_proj_shards,
          initializer=initializer,
          state_is_tuple=False)

      outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)

      self.assertEqual(len(outputs), len(inputs))

      variables_lib.global_variables_initializer().run()
      input_value = np.random.randn(batch_size, input_size)
      sess.run(outputs, feed_dict={inputs[0]: input_value})

  @test_util.run_v1_only("b/124229375")
  def testDoubleInput(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    num_proj = 4
    num_proj_shards = 3
    num_unit_shards = 2
    max_length = 8
    with self.session(graph=ops.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed)
      inputs = max_length * [
          array_ops.placeholder(dtypes.float64, shape=(None, input_size))
      ]

      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
          num_unit_shards=num_unit_shards,
          num_proj_shards=num_proj_shards,
          initializer=initializer,
          state_is_tuple=False)

      outputs, _ = rnn.static_rnn(
          cell,
          inputs,
          initial_state=cell.zero_state(batch_size, dtypes.float64))

      self.assertEqual(len(outputs), len(inputs))

      variables_lib.global_variables_initializer().run()
      input_value = np.asarray(
          np.random.randn(batch_size, input_size), dtype=np.float64)
      values = sess.run(outputs, feed_dict={inputs[0]: input_value})
      self.assertEqual(values[0].dtype, input_value.dtype)

  @test_util.run_v1_only("b/124229375")
  def testShardNoShardEquivalentOutput(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    num_proj = 4
    num_proj_shards = 3
    num_unit_shards = 2
    max_length = 8
    with self.session(graph=ops.Graph()) as sess:
      inputs = max_length * [
          array_ops.placeholder(dtypes.float32, shape=(None, input_size))
      ]
      initializer = init_ops.constant_initializer(0.001)

      cell_noshard = rnn_cell.LSTMCell(
          num_units,
          num_proj=num_proj,
          use_peepholes=True,
          initializer=initializer,
          num_unit_shards=num_unit_shards,
          num_proj_shards=num_proj_shards,
          state_is_tuple=False)

      cell_shard = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          initializer=initializer,
          num_proj=num_proj,
          state_is_tuple=False)

      with variable_scope.variable_scope("noshard_scope"):
        outputs_noshard, state_noshard = rnn.static_rnn(
            cell_noshard, inputs, dtype=dtypes.float32)
      with variable_scope.variable_scope("shard_scope"):
        outputs_shard, state_shard = rnn.static_rnn(
            cell_shard, inputs, dtype=dtypes.float32)

      self.assertEqual(len(outputs_noshard), len(inputs))
      self.assertEqual(len(outputs_noshard), len(outputs_shard))

      variables_lib.global_variables_initializer().run()
      input_value = np.random.randn(batch_size, input_size)
      feeds = dict((x, input_value) for x in inputs)
      values_noshard = sess.run(outputs_noshard, feed_dict=feeds)
      values_shard = sess.run(outputs_shard, feed_dict=feeds)
      state_values_noshard = sess.run([state_noshard], feed_dict=feeds)
      state_values_shard = sess.run([state_shard], feed_dict=feeds)
      self.assertEqual(len(values_noshard), len(values_shard))
      self.assertEqual(len(state_values_noshard), len(state_values_shard))
      for (v_noshard, v_shard) in zip(values_noshard, values_shard):
        self.assertAllClose(v_noshard, v_shard, atol=1e-3)
      for (s_noshard, s_shard) in zip(state_values_noshard, state_values_shard):
        self.assertAllClose(s_noshard, s_shard, atol=1e-3)

  @test_util.run_v1_only("b/124229375")
  def testDoubleInputWithDropoutAndDynamicCalculation(self):
    """Smoke test for using LSTM with doubles, dropout, dynamic calculation."""

    num_units = 3
    input_size = 5
    batch_size = 2
    num_proj = 4
    num_proj_shards = 3
    num_unit_shards = 2
    max_length = 8
    with self.session(graph=ops.Graph()) as sess:
      sequence_length = array_ops.placeholder(dtypes.int64)
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)
      inputs = max_length * [
          array_ops.placeholder(dtypes.float64, shape=(None, input_size))
      ]

      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
          num_unit_shards=num_unit_shards,
          num_proj_shards=num_proj_shards,
          initializer=initializer,
          state_is_tuple=False)
      dropout_cell = rnn_cell.DropoutWrapper(cell, 0.5, seed=0)

      outputs, state = rnn.static_rnn(
          dropout_cell,
          inputs,
          sequence_length=sequence_length,
          initial_state=cell.zero_state(batch_size, dtypes.float64))

      self.assertEqual(len(outputs), len(inputs))

      variables_lib.global_variables_initializer().run(feed_dict={
          sequence_length: [2, 3]
      })
      input_value = np.asarray(
          np.random.randn(batch_size, input_size), dtype=np.float64)
      values = sess.run(
          outputs, feed_dict={
              inputs[0]: input_value,
              sequence_length: [2, 3]
          })
      state_value = sess.run(
          [state], feed_dict={
              inputs[0]: input_value,
              sequence_length: [2, 3]
          })
      self.assertEqual(values[0].dtype, input_value.dtype)
      self.assertEqual(state_value[0].dtype, input_value.dtype)

  @test_util.run_v1_only("b/124229375")
  def testSharingWeightsWithReuse(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    num_proj = 4
    max_length = 8
    with self.session(graph=ops.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed)
      initializer_d = init_ops.random_uniform_initializer(
          -1, 1, seed=self._seed + 1)
      inputs = max_length * [
          array_ops.placeholder(dtypes.float32, shape=(None, input_size))
      ]
      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
          initializer=initializer,
          state_is_tuple=False)
      cell_d = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
          initializer=initializer_d,
          state_is_tuple=False)

      with variable_scope.variable_scope("share_scope"):
        outputs0, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
      with variable_scope.variable_scope("share_scope", reuse=True):
        outputs1, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
      with variable_scope.variable_scope("diff_scope"):
        outputs2, _ = rnn.static_rnn(cell_d, inputs, dtype=dtypes.float32)

      variables_lib.global_variables_initializer().run()
      input_value = np.random.randn(batch_size, input_size)
      output_values = sess.run(
          outputs0 + outputs1 + outputs2, feed_dict={
              inputs[0]: input_value
          })
      outputs0_values = output_values[:max_length]
      outputs1_values = output_values[max_length:2 * max_length]
      outputs2_values = output_values[2 * max_length:]
      self.assertEqual(len(outputs0_values), len(outputs1_values))
      self.assertEqual(len(outputs0_values), len(outputs2_values))
      for o1, o2, o3 in zip(outputs0_values, outputs1_values, outputs2_values):
        # Same weights used by both RNNs so outputs should be the same.
        self.assertAllEqual(o1, o2)
        # Different weights used so outputs should be different.
        self.assertTrue(np.linalg.norm(o1 - o3) > 1e-6)

  @test_util.run_v1_only("b/124229375")
  def testSharingWeightsWithDifferentNamescope(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    num_proj = 4
    max_length = 8
    with self.session(graph=ops.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed)
      inputs = max_length * [
          array_ops.placeholder(dtypes.float32, shape=(None, input_size))
      ]
      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
          initializer=initializer,
          state_is_tuple=False)

      with ops.name_scope("scope0"):
        with variable_scope.variable_scope("share_scope"):
          outputs0, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
      with ops.name_scope("scope1"):
        with variable_scope.variable_scope("share_scope", reuse=True):
          outputs1, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32)

      variables_lib.global_variables_initializer().run()
      input_value = np.random.randn(batch_size, input_size)
      output_values = sess.run(
          outputs0 + outputs1, feed_dict={
              inputs[0]: input_value
          })
      outputs0_values = output_values[:max_length]
      outputs1_values = output_values[max_length:]
      self.assertEqual(len(outputs0_values), len(outputs1_values))
      for out0, out1 in zip(outputs0_values, outputs1_values):
        self.assertAllEqual(out0, out1)

  @test_util.run_v1_only("b/124229375")
  def testDynamicRNNAllowsUnknownTimeDimension(self):
    inputs = array_ops.placeholder(dtypes.float32, shape=[1, None, 20])
    cell = rnn_cell.GRUCell(30)
    # Smoke test, this should not raise an error
    rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)

  @test_util.run_in_graph_and_eager_modes
  def testDynamicRNNWithTupleStates(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    num_proj = 4
    max_length = 8
    sequence_length = [4, 6]
    in_graph_mode = not context.executing_eagerly()
    with self.session(graph=ops.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)
      if in_graph_mode:
        inputs = max_length * [
            array_ops.placeholder(dtypes.float32, shape=(None, input_size))
        ]
      else:
        inputs = max_length * [
            constant_op.constant(
                np.random.randn(batch_size, input_size).astype(np.float32))
        ]
      inputs_c = array_ops.stack(inputs)
      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          num_proj=num_proj,
          initializer=initializer,
          state_is_tuple=True)
      with variable_scope.variable_scope("root") as scope:
        outputs_static, state_static = rnn.static_rnn(
            cell,
            inputs,
            dtype=dtypes.float32,
            sequence_length=sequence_length,
            scope=scope)
        scope.reuse_variables()
        outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
            cell,
            inputs_c,
            dtype=dtypes.float32,
            time_major=True,
            sequence_length=sequence_length,
            scope=scope)
      self.assertTrue(isinstance(state_static, rnn_cell.LSTMStateTuple))
      self.assertTrue(isinstance(state_dynamic, rnn_cell.LSTMStateTuple))
      self.assertIs(state_static[0], state_static.c)
      self.assertIs(state_static[1], state_static.h)
      self.assertIs(state_dynamic[0], state_dynamic.c)
      self.assertIs(state_dynamic[1], state_dynamic.h)

      if in_graph_mode:
        variables_lib.global_variables_initializer().run()
        input_value = np.random.randn(batch_size, input_size)
        outputs_static = sess.run(
            outputs_static, feed_dict={
                inputs[0]: input_value
            })
        outputs_dynamic = sess.run(
            outputs_dynamic, feed_dict={
                inputs[0]: input_value
            })
        state_static = sess.run(
            state_static, feed_dict={
                inputs[0]: input_value
            })
        state_dynamic = sess.run(
            state_dynamic, feed_dict={
                inputs[0]: input_value
            })

      comparison_fn = self.assertAllEqual
      if test_util.is_xla_enabled():
        comparison_fn = self.assertAllClose
      if in_graph_mode:
        comparison_fn(outputs_static, outputs_dynamic)
      else:
        self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic)
      comparison_fn(np.hstack(state_static), np.hstack(state_dynamic))

  @test_util.run_in_graph_and_eager_modes
  def testDynamicRNNWithNestedTupleStates(self):
    num_units = 3
    input_size = 5
    batch_size = 2
    num_proj = 4
    max_length = 8
    sequence_length = [4, 6]
    in_graph_mode = not context.executing_eagerly()
    with self.session(graph=ops.Graph()) as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)
      if in_graph_mode:
        inputs = max_length * [
            array_ops.placeholder(dtypes.float32, shape=(None, input_size))
        ]
      else:
        inputs = max_length * [
            constant_op.constant(
                np.random.randn(batch_size, input_size).astype(np.float32))
        ]
      inputs_c = array_ops.stack(inputs)

      def _cell(i):
        return rnn_cell.LSTMCell(
            num_units + i,
            use_peepholes=True,
            num_proj=num_proj + i,
            initializer=initializer,
            state_is_tuple=True)

      # This creates a state tuple which has 4 sub-tuples of length 2 each.
      cell = rnn_cell.MultiRNNCell(
          [_cell(i) for i in range(4)], state_is_tuple=True)

      self.assertEqual(len(cell.state_size), 4)
      for i in range(4):
        self.assertEqual(len(cell.state_size[i]), 2)

      test_zero = cell.zero_state(1, dtypes.float32)
      self.assertEqual(len(test_zero), 4)
      for i in range(4):
        self.assertEqual(test_zero[i][0].get_shape()[1], cell.state_size[i][0])
        self.assertEqual(test_zero[i][1].get_shape()[1], cell.state_size[i][1])

      with variable_scope.variable_scope("root") as scope:
        outputs_static, state_static = rnn.static_rnn(
            cell,
            inputs,
            dtype=dtypes.float32,
            sequence_length=sequence_length,
            scope=scope)
        scope.reuse_variables()
        outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
            cell,
            inputs_c,
            dtype=dtypes.float32,
            time_major=True,
            sequence_length=sequence_length,
            scope=scope)

      if in_graph_mode:
        input_value = np.random.randn(batch_size, input_size)
        variables_lib.global_variables_initializer().run()
        outputs_static = sess.run(
            outputs_static, feed_dict={
                inputs[0]: input_value
            })
        outputs_dynamic = sess.run(
            outputs_dynamic, feed_dict={
                inputs[0]: input_value
            })
        state_static = sess.run(
            nest.flatten(state_static), feed_dict={
                inputs[0]: input_value
            })
        state_dynamic = sess.run(
            nest.flatten(state_dynamic), feed_dict={
                inputs[0]: input_value
            })

      comparison_fn = self.assertAllEqual
      if test_util.is_xla_enabled():
        comparison_fn = self.assertAllClose
      if in_graph_mode:
        comparison_fn(outputs_static, outputs_dynamic)
      else:
        self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic)
      state_static = nest.flatten(state_static)
      state_dynamic = nest.flatten(state_dynamic)
      comparison_fn(np.hstack(state_static), np.hstack(state_dynamic))

  def _testDynamicEquivalentToStaticRNN(self, use_sequence_length):
    time_steps = 8
    num_units = 3
    num_proj = 4
    input_size = 5
    batch_size = 2

    input_values = np.random.randn(time_steps, batch_size, input_size).astype(
        np.float32)

    if use_sequence_length:
      sequence_length = np.random.randint(0, time_steps, size=batch_size)
    else:
      sequence_length = None

    in_graph_mode = not context.executing_eagerly()

    # TODO(b/68017812): Eager ignores operation seeds, so we need to create a
    # single cell and reuse it across the static and dynamic RNNs. Remove this
    # special case once it is fixed.
    if not in_graph_mode:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)
      cell = rnn_cell.LSTMCell(
          num_units,
          use_peepholes=True,
          initializer=initializer,
          num_proj=num_proj,
          state_is_tuple=False)

    ########### Step 1: Run static graph and generate readouts
    with self.session(graph=ops.Graph()) as sess:
      if in_graph_mode:
        concat_inputs = array_ops.placeholder(
            dtypes.float32, shape=(time_steps, batch_size, input_size))
      else:
        concat_inputs = constant_op.constant(input_values)
      inputs = array_ops.unstack(concat_inputs)
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)

      # TODO(akshayka): Remove special case once b/68017812 is fixed.
      if in_graph_mode:
        cell = rnn_cell.LSTMCell(
            num_units,
            use_peepholes=True,
            initializer=initializer,
            num_proj=num_proj,
            state_is_tuple=False)

      with variable_scope.variable_scope("dynamic_scope"):
        outputs_static, state_static = rnn.static_rnn(
            cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)

      if in_graph_mode:
        # Generate gradients of sum of outputs w.r.t. inputs
        static_gradients = gradients_impl.gradients(
            outputs_static + [state_static], [concat_inputs])
        # Generate gradients of individual outputs w.r.t. inputs
        static_individual_gradients = nest.flatten([
            gradients_impl.gradients(y, [concat_inputs])
            for y in [outputs_static[0], outputs_static[-1], state_static]
        ])
        # Generate gradients of individual variables w.r.t. inputs
        trainable_variables = ops.get_collection(
            ops.GraphKeys.TRAINABLE_VARIABLES)
        assert len(trainable_variables) > 1, (
            "Count of trainable variables: %d" % len(trainable_variables))
        # pylint: disable=bad-builtin
        static_individual_variable_gradients = nest.flatten([
            gradients_impl.gradients(y, trainable_variables)
            for y in [outputs_static[0], outputs_static[-1], state_static]
        ])
        # Generate gradients and run sessions to obtain outputs
        feeds = {concat_inputs: input_values}
        # Initialize
        variables_lib.global_variables_initializer().run(feed_dict=feeds)
        # Test forward pass
        values_static = sess.run(outputs_static, feed_dict=feeds)
        (state_value_static,) = sess.run((state_static,), feed_dict=feeds)

        # Test gradients to inputs and variables w.r.t. outputs & final state
        static_grad_values = sess.run(static_gradients, feed_dict=feeds)

        static_individual_grad_values = sess.run(
            static_individual_gradients, feed_dict=feeds)

        static_individual_var_grad_values = sess.run(
            static_individual_variable_gradients, feed_dict=feeds)

    ########## Step 2: Run dynamic graph and generate readouts
    with self.session(graph=ops.Graph()) as sess:
      if in_graph_mode:
        concat_inputs = array_ops.placeholder(
            dtypes.float32, shape=(time_steps, batch_size, input_size))
      else:
        concat_inputs = constant_op.constant(input_values)
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=self._seed)

      # TODO(akshayka): Remove this special case once b/68017812 is
      # fixed.
      if in_graph_mode:
        cell = rnn_cell.LSTMCell(
            num_units,
            use_peepholes=True,
            initializer=initializer,
            num_proj=num_proj,
            state_is_tuple=False)

      with variable_scope.variable_scope("dynamic_scope"):
        outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
            cell,
            inputs=concat_inputs,
            sequence_length=sequence_length,
            time_major=True,
            dtype=dtypes.float32)
        split_outputs_dynamic = array_ops.unstack(outputs_dynamic, time_steps)

      if in_graph_mode:

        # Generate gradients of sum of outputs w.r.t. inputs
        dynamic_gradients = gradients_impl.gradients(
            split_outputs_dynamic + [state_dynamic], [concat_inputs])

        # Generate gradients of several individual outputs w.r.t. inputs
        dynamic_individual_gradients = nest.flatten([
            gradients_impl.gradients(y, [concat_inputs])
            for y in [
                split_outputs_dynamic[0], split_outputs_dynamic[-1],
                state_dynamic
            ]
        ])

        # Generate gradients of individual variables w.r.t. inputs
        trainable_variables = ops.get_collection(
            ops.GraphKeys.TRAINABLE_VARIABLES)
        assert len(trainable_variables) > 1, (
            "Count of trainable variables: %d" % len(trainable_variables))
        dynamic_individual_variable_gradients = nest.flatten([
            gradients_impl.gradients(y, trainable_variables)
            for y in [
                split_outputs_dynamic[0], split_outputs_dynamic[-1],
                state_dynamic
            ]
        ])

        feeds = {concat_inputs: input_values}

        # Initialize
        variables_lib.global_variables_initializer().run(feed_dict=feeds)

        # Test forward pass
        values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds)
        (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds)

        # Test gradients to inputs and variables w.r.t. outputs & final state
        dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds)

        dynamic_individual_grad_values = sess.run(
            dynamic_individual_gradients, feed_dict=feeds)

        dynamic_individual_var_grad_values = sess.run(
            dynamic_individual_variable_gradients, feed_dict=feeds)

    ######### Step 3: Comparisons
    if not in_graph_mode:
      values_static = outputs_static
      values_dynamic = split_outputs_dynamic
      state_value_static = state_static
      state_value_dynamic = state_dynamic

    self.assertEqual(len(values_static), len(values_dynamic))
    for (value_static, value_dynamic) in zip(values_static, values_dynamic):
      self.assertAllClose(value_static, value_dynamic)
    self.assertAllClose(state_value_static, state_value_dynamic)

    if in_graph_mode:

      self.assertAllClose(static_grad_values, dynamic_grad_values)

      self.assertEqual(
          len(static_individual_grad_values),
          len(dynamic_individual_grad_values))
      self.assertEqual(
          len(static_individual_var_grad_values),
          len(dynamic_individual_var_grad_values))

      for i, (a, b) in enumerate(
          zip(static_individual_grad_values, dynamic_individual_grad_values)):
        tf_logging.info("Comparing individual gradients iteration %d" % i)
        self.assertAllClose(a, b)

      for i, (a, b) in enumerate(
          zip(static_individual_var_grad_values,
              dynamic_individual_var_grad_values)):
        tf_logging.info(
            "Comparing individual variable gradients iteration %d" % i)
        self.assertAllClose(a, b)

  @test_util.run_in_graph_and_eager_modes
  def testDynamicEquivalentToStaticRNN(self):
    self._testDynamicEquivalentToStaticRNN(use_sequence_length=False)

  @test_util.run_in_graph_and_eager_modes
  def testDynamicEquivalentToStaticRNNWithSequenceLength(self):
    self._testDynamicEquivalentToStaticRNN(use_sequence_length=True)

  @test_util.run_in_graph_and_eager_modes
  def testLSTMBlockCellErrorHandling(self):
    forget_bias = 1
    cell_clip = 0
    use_peephole = False
    x = constant_op.constant(0.837607, shape=[28, 29], dtype=dtypes.float32)
    cs_prev = constant_op.constant(0, shape=[28, 17], dtype=dtypes.float32)
    h_prev = constant_op.constant(
        0.592631638, shape=[28, 17], dtype=dtypes.float32)
    w = constant_op.constant(0.887386262, shape=[46, 68], dtype=dtypes.float32)
    wci = constant_op.constant(0, shape=[], dtype=dtypes.float32)
    wcf = constant_op.constant(0, shape=[17], dtype=dtypes.float32)
    wco = constant_op.constant(
        0.592631638, shape=[28, 17], dtype=dtypes.float32)
    b = constant_op.constant(0.75259006, shape=[68], dtype=dtypes.float32)
    with self.assertRaises(errors_impl.InvalidArgumentError):
      self.evaluate(
          gen_rnn_ops.lstm_block_cell(
              x=x,
              cs_prev=cs_prev,
              h_prev=h_prev,
              w=w,
              wci=wci,
              wcf=wcf,
              wco=wco,
              b=b,
              forget_bias=forget_bias,
              cell_clip=cell_clip,
              use_peephole=use_peephole))

  @test_util.run_in_graph_and_eager_modes
  def testLSTMBlockCellGradErrorHandling(self):
    use_peephole = False
    seq_len_max = constant_op.constant(1, shape=[], dtype=dtypes.int64)
    x = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
    cs_prev = constant_op.constant(
        0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
    h_prev = constant_op.constant(
        0.504355371, shape=[1, 1], dtype=dtypes.float32)
    w = constant_op.constant(0.504355371, shape=[1, 1], dtype=dtypes.float32)
    wci = constant_op.constant(0.504355371, shape=[1], dtype=dtypes.float32)
    wcf = constant_op.constant(0.504355371, shape=[1], dtype=dtypes.float32)
    wco = constant_op.constant(0.504355371, shape=[1], dtype=dtypes.float32)
    b = constant_op.constant(0.504355371, shape=[1], dtype=dtypes.float32)
    i = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
    cs = constant_op.constant(
        0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
    f = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
    o = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
    ci = constant_op.constant(
        0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
    co = constant_op.constant(
        0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
    h = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
    cs_grad = constant_op.constant(
        0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
    h_grad = constant_op.constant(
        0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
    with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
                                "must be rank"):
      self.evaluate(
          gen_rnn_ops.block_lstm_grad_v2(
              seq_len_max=seq_len_max,
              x=x,
              cs_prev=cs_prev,
              h_prev=h_prev,
              w=w,
              wci=wci,
              wcf=wcf,
              wco=wco,
              b=b,
              i=i,
              cs=cs,
              f=f,
              o=o,
              ci=ci,
              co=co,
              h=h,
              cs_grad=cs_grad,
              h_grad=h_grad,
              use_peephole=use_peephole))


class BidirectionalRNNTest(test.TestCase):

  def setUp(self):
    self._seed = 23489
    np.random.seed(self._seed)

  def _createBidirectionalRNN(self, use_shape, use_sequence_length, scope=None):
    num_units = 3
    input_size = 5
    batch_size = 2
    max_length = 8

    initializer = init_ops.random_uniform_initializer(
        -0.01, 0.01, seed=self._seed)
    sequence_length = array_ops.placeholder(
        dtypes.int64) if use_sequence_length else None
    cell_fw = rnn_cell.LSTMCell(
        num_units, input_size, initializer=initializer, state_is_tuple=False)
    cell_bw = rnn_cell.LSTMCell(
        num_units, input_size, initializer=initializer, state_is_tuple=False)
    inputs = max_length * [
        array_ops.placeholder(
            dtypes.float32,
            shape=(batch_size, input_size) if use_shape else (None, input_size))
    ]
    outputs, state_fw, state_bw = rnn.static_bidirectional_rnn(
        cell_fw,
        cell_bw,
        inputs,
        dtype=dtypes.float32,
        sequence_length=sequence_length,
        scope=scope)
    self.assertEqual(len(outputs), len(inputs))
    for out in outputs:
      self.assertEqual(out.get_shape().as_list(),
                       [batch_size if use_shape else None, 2 * num_units])

    input_value = np.random.randn(batch_size, input_size)
    outputs = array_ops.stack(outputs)

    return input_value, inputs, outputs, state_fw, state_bw, sequence_length

  def _testBidirectionalRNN(self, use_shape):
    with self.session(graph=ops.Graph()) as sess:
      input_value, inputs, outputs, state_fw, state_bw, sequence_length = (
          self._createBidirectionalRNN(use_shape, True))
      variables_lib.global_variables_initializer().run()
      # Run with pre-specified sequence length of 2, 3
      out, s_fw, s_bw = sess.run(
          [outputs, state_fw, state_bw],
          feed_dict={
              inputs[0]: input_value,
              sequence_length: [2, 3]
          })

      # Since the forward and backward LSTM cells were initialized with the
      # same parameters, the forward and backward output has to be the same,
      # but reversed in time. The format is output[time][batch][depth], and
      # due to depth concatenation (as num_units=3 for both RNNs):
      # - forward output:  out[][][depth] for 0 <= depth < 3
      # - backward output: out[][][depth] for 3 <= depth < 6
      #
      # First sequence in batch is length=2
      # Check that the time=0 forward output is equal to time=1 backward output
      self.assertAllClose(out[0][0][0], out[1][0][3])
      self.assertAllClose(out[0][0][1], out[1][0][4])
      self.assertAllClose(out[0][0][2], out[1][0][5])
      # Check that the time=1 forward output is equal to time=0 backward output
      self.assertAllClose(out[1][0][0], out[0][0][3])
      self.assertAllClose(out[1][0][1], out[0][0][4])
      self.assertAllClose(out[1][0][2], out[0][0][5])

      # Second sequence in batch is length=3
      # Check that the time=0 forward output is equal to time=2 backward output
      self.assertAllClose(out[0][1][0], out[2][1][3])
      self.assertAllClose(out[0][1][1], out[2][1][4])
      self.assertAllClose(out[0][1][2], out[2][1][5])
      # Check that the time=1 forward output is equal to time=1 backward output
      self.assertAllClose(out[1][1][0], out[1][1][3])
      self.assertAllClose(out[1][1][1], out[1][1][4])
      self.assertAllClose(out[1][1][2], out[1][1][5])
      # Check that the time=2 forward output is equal to time=0 backward output
      self.assertAllClose(out[2][1][0], out[0][1][3])
      self.assertAllClose(out[2][1][1], out[0][1][4])
      self.assertAllClose(out[2][1][2], out[0][1][5])
      # Via the reasoning above, the forward and backward final state should be
      # exactly the same
      self.assertAllClose(s_fw, s_bw)

  def _testBidirectionalRNNWithoutSequenceLength(self, use_shape):
    with self.session(graph=ops.Graph()) as sess:
      input_value, inputs, outputs, state_fw, state_bw, _ = (
          self._createBidirectionalRNN(use_shape, False))
      variables_lib.global_variables_initializer().run()
      out, s_fw, s_bw = sess.run(
          [outputs, state_fw, state_bw], feed_dict={
              inputs[0]: input_value
          })

      # Since the forward and backward LSTM cells were initialized with the
      # same parameters, the forward and backward output has to be the same,
      # but reversed in time. The format is output[time][batch][depth], and
      # due to depth concatenation (as num_units=3 for both RNNs):
      # - forward output:  out[][][depth] for 0 <= depth < 3
      # - backward output: out[][][depth] for 3 <= depth < 6
      #
      # Both sequences in batch are length=8.  Check that the time=i
      # forward output is equal to time=8-1-i backward output
      for i in range(8):
        self.assertAllClose(out[i][0][0:3], out[8 - 1 - i][0][3:6])
        self.assertAllClose(out[i][1][0:3], out[8 - 1 - i][1][3:6])
      # Via the reasoning above, the forward and backward final state should be
      # exactly the same
      self.assertAllClose(s_fw, s_bw)

  @test_util.run_v1_only("b/124229375")
  def testBidirectionalRNN(self):
    self._testBidirectionalRNN(use_shape=False)
    self._testBidirectionalRNN(use_shape=True)

  @test_util.run_v1_only("b/124229375")
  def testBidirectionalRNNWithoutSequenceLength(self):
    self._testBidirectionalRNNWithoutSequenceLength(use_shape=False)
    self._testBidirectionalRNNWithoutSequenceLength(use_shape=True)

  def _createBidirectionalDynamicRNN(self,
                                     use_shape,
                                     use_state_tuple,
                                     use_time_major,
                                     use_sequence_length,
                                     scope=None):
    num_units = 3
    input_size = 5
    batch_size = 2
    max_length = 8

    initializer = init_ops.random_uniform_initializer(
        -0.01, 0.01, seed=self._seed)
    sequence_length = (
        array_ops.placeholder(dtypes.int64) if use_sequence_length else None)
    cell_fw = rnn_cell.LSTMCell(
        num_units, initializer=initializer, state_is_tuple=use_state_tuple)
    cell_bw = rnn_cell.LSTMCell(
        num_units, initializer=initializer, state_is_tuple=use_state_tuple)
    inputs = max_length * [
        array_ops.placeholder(
            dtypes.float32,
            shape=(batch_size if use_shape else None, input_size))
    ]
    inputs_c = array_ops.stack(inputs)
    if not use_time_major:
      inputs_c = array_ops.transpose(inputs_c, [1, 0, 2])
    outputs, states = rnn.bidirectional_dynamic_rnn(
        cell_fw,
        cell_bw,
        inputs_c,
        sequence_length,
        dtype=dtypes.float32,
        time_major=use_time_major,
        scope=scope)
    outputs = array_ops.concat(outputs, 2)
    state_fw, state_bw = states
    outputs_shape = [None, max_length, 2 * num_units]
    if use_shape:
      outputs_shape[0] = batch_size
    if use_time_major:
      outputs_shape[0], outputs_shape[1] = outputs_shape[1], outputs_shape[0]
    self.assertEqual(outputs.get_shape().as_list(), outputs_shape)

    input_value = np.random.randn(batch_size, input_size)

    return input_value, inputs, outputs, state_fw, state_bw, sequence_length

  def _testBidirectionalDynamicRNN(self, use_shape, use_state_tuple,
                                   use_time_major, use_sequence_length):
    with self.session(graph=ops.Graph()) as sess:
      input_value, inputs, outputs, state_fw, state_bw, sequence_length = (
          self._createBidirectionalDynamicRNN(
              use_shape, use_state_tuple, use_time_major, use_sequence_length))
      variables_lib.global_variables_initializer().run()
      # Run with pre-specified sequence length of 2, 3
      feed_dict = ({sequence_length: [2, 3]} if use_sequence_length else {})
      feed_dict.update({inputs[0]: input_value})
      if use_state_tuple:
        out, c_fw, m_fw, c_bw, m_bw = sess.run(
            [outputs, state_fw[0], state_fw[1], state_bw[0], state_bw[1]],
            feed_dict=feed_dict)
        s_fw = (c_fw, m_fw)
        s_bw = (c_bw, m_bw)
      else:
        feed_dict.update({inputs[0]: input_value})
        out, s_fw, s_bw = sess.run(
            [outputs, state_fw, state_bw], feed_dict=feed_dict)

      # Since the forward and backward LSTM cells were initialized with the
      # same parameters, the forward and backward output has to be the same,
time. The format is output[time][batch][depth], and 1607 # due to depth concatenation (as num_units=3 for both RNNs): 1608 # - forward output: out[][][depth] for 0 <= depth < 3 1609 # - backward output: out[][][depth] for 3 <= depth < 6 1610 # 1611 if not use_time_major: 1612 out = np.swapaxes(out, 0, 1) 1613 1614 if use_sequence_length: 1615 # First sequence in batch is length=2 1616 # Check that the t=0 forward output is equal to t=1 backward output 1617 self.assertEqual(out[0][0][0], out[1][0][3]) 1618 self.assertEqual(out[0][0][1], out[1][0][4]) 1619 self.assertEqual(out[0][0][2], out[1][0][5]) 1620 # Check that the t=1 forward output is equal to t=0 backward output 1621 self.assertEqual(out[1][0][0], out[0][0][3]) 1622 self.assertEqual(out[1][0][1], out[0][0][4]) 1623 self.assertEqual(out[1][0][2], out[0][0][5]) 1624 1625 # Second sequence in batch is length=3 1626 # Check that the t=0 forward output is equal to t=2 backward output 1627 self.assertEqual(out[0][1][0], out[2][1][3]) 1628 self.assertEqual(out[0][1][1], out[2][1][4]) 1629 self.assertEqual(out[0][1][2], out[2][1][5]) 1630 # Check that the t=1 forward output is equal to t=1 backward output 1631 self.assertEqual(out[1][1][0], out[1][1][3]) 1632 self.assertEqual(out[1][1][1], out[1][1][4]) 1633 self.assertEqual(out[1][1][2], out[1][1][5]) 1634 # Check that the t=2 forward output is equal to t=0 backward output 1635 self.assertEqual(out[2][1][0], out[0][1][3]) 1636 self.assertEqual(out[2][1][1], out[0][1][4]) 1637 self.assertEqual(out[2][1][2], out[0][1][5]) 1638 # Via the reasoning above, the forward and backward final state should 1639 # be exactly the same 1640 self.assertAllClose(s_fw, s_bw) 1641 else: # not use_sequence_length 1642 max_length = 8 # from _createBidirectionalDynamicRNN 1643 for t in range(max_length): 1644 self.assertAllEqual(out[t, :, 0:3], out[max_length - t - 1, :, 3:6]) 1645 self.assertAllClose(s_fw, s_bw) 1646 1647 @test_util.run_v1_only("b/124229375") 1648 def testBidirectionalDynamicRNN(self): 1649 # Generate 2^4 option values 1650 # from [True, True, True, True] to [False, False, False, False] 1651 options = itertools.product([True, False], repeat=4) 1652 for option in options: 1653 self._testBidirectionalDynamicRNN( 1654 use_shape=option[0], 1655 use_state_tuple=option[1], 1656 use_time_major=option[2], 1657 use_sequence_length=option[3]) 1658 1659 def _testScope(self, factory, prefix="prefix", use_outer_scope=True): 1660 # REMARKS: factory(scope) is a function accepting a scope 1661 # as an argument; the scope can be None, a string, 1662 # or a VariableScope instance. 1663 with self.session(graph=ops.Graph()): 1664 if use_outer_scope: 1665 with variable_scope.variable_scope(prefix) as scope: 1666 factory(scope) 1667 else: 1668 factory(prefix) 1669 1670 # Check that all the variable names start 1671 # with the proper scope.
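# (With prefix=None the factories fall back to the RNN functions' default
# "bidirectional_rnn" scope, hence the fallback prefix used below.)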
1672 variables_lib.global_variables_initializer() 1673 all_vars = variables_lib.global_variables() 1674 prefix = prefix or "bidirectional_rnn" 1675 scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] 1676 tf_logging.info("BiRNN with scope: %s (%s)" % 1677 (prefix, "scope" if use_outer_scope else "str")) 1678 for v in scope_vars: 1679 tf_logging.info(v.name) 1680 self.assertEqual(len(scope_vars), len(all_vars)) 1681 1682 @test_util.run_v1_only("b/124229375") 1683 def testBidirectionalRNNScope(self): 1684 1685 def factory(scope): 1686 return self._createBidirectionalRNN( 1687 use_shape=True, use_sequence_length=True, scope=scope) 1688 1689 self._testScope(factory, use_outer_scope=True) 1690 self._testScope(factory, use_outer_scope=False) 1691 self._testScope(factory, prefix=None, use_outer_scope=False) 1692 1693 @test_util.run_v1_only("b/124229375") 1694 def testBidirectionalDynamicRNNScope(self): 1695 1696 def get_factory(use_time_major): 1697 1698 def factory(scope): 1699 return self._createBidirectionalDynamicRNN( 1700 use_shape=True, 1701 use_state_tuple=True, 1702 use_sequence_length=True, 1703 use_time_major=use_time_major, 1704 scope=scope) 1705 1706 return factory 1707 1708 self._testScope(get_factory(True), use_outer_scope=True) 1709 self._testScope(get_factory(True), use_outer_scope=False) 1710 self._testScope(get_factory(True), prefix=None, use_outer_scope=False) 1711 self._testScope(get_factory(False), use_outer_scope=True) 1712 self._testScope(get_factory(False), use_outer_scope=False) 1713 self._testScope(get_factory(False), prefix=None, use_outer_scope=False) 1714 1715 1716class MultiDimensionalLSTMTest(test.TestCase): 1717 1718 def setUp(self): 1719 self._seed = 23489 1720 np.random.seed(self._seed) 1721 1722 @test_util.run_v1_only("b/124229375") 1723 def testMultiDimensionalLSTMAllRNNContainers(self): 1724 feature_dims = (3, 4, 5) 1725 input_size = feature_dims 1726 batch_size = 2 1727 max_length = 8 1728 sequence_length = [4, 6] 1729 with self.session(graph=ops.Graph()) as sess: 1730 inputs = max_length * [ 1731 array_ops.placeholder(dtypes.float32, shape=(None,) + input_size) 1732 ] 1733 inputs_using_dim = max_length * [ 1734 array_ops.placeholder( 1735 dtypes.float32, shape=(batch_size,) + input_size) 1736 ] 1737 inputs_c = array_ops.stack(inputs) 1738 # Create a cell for the whole test. This is fine because the cell has no 1739 # variables. 
1740 cell = DummyMultiDimensionalLSTM(feature_dims) 1741 state_saver = TestStateSaver(batch_size, input_size) 1742 outputs_static, state_static = rnn.static_rnn( 1743 cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length) 1744 outputs_dynamic, state_dynamic = rnn.dynamic_rnn( 1745 cell, 1746 inputs_c, 1747 dtype=dtypes.float32, 1748 time_major=True, 1749 sequence_length=sequence_length) 1750 outputs_bid, state_fw, state_bw = rnn.static_bidirectional_rnn( 1751 cell, 1752 cell, 1753 inputs_using_dim, 1754 dtype=dtypes.float32, 1755 sequence_length=sequence_length) 1756 outputs_sav, state_sav = rnn.static_state_saving_rnn( 1757 cell, 1758 inputs_using_dim, 1759 sequence_length=sequence_length, 1760 state_saver=state_saver, 1761 state_name=("h", "c")) 1762 1763 self.assertEqual(outputs_dynamic.get_shape().as_list(), 1764 inputs_c.get_shape().as_list()) 1765 for out, inp in zip(outputs_static, inputs): 1766 self.assertEqual(out.get_shape().as_list(), inp.get_shape().as_list()) 1767 for out, inp in zip(outputs_bid, inputs_using_dim): 1768 input_shape_list = inp.get_shape().as_list() 1769 # fwd and bwd activations are concatenated along the second dim. 1770 input_shape_list[1] *= 2 1771 self.assertEqual(out.get_shape().as_list(), input_shape_list) 1772 1773 variables_lib.global_variables_initializer().run() 1774 1775 input_total_size = (batch_size,) + input_size 1776 input_value = np.random.randn(*input_total_size) 1777 outputs_static_v = sess.run( 1778 outputs_static, feed_dict={ 1779 inputs[0]: input_value 1780 }) 1781 outputs_dynamic_v = sess.run( 1782 outputs_dynamic, feed_dict={ 1783 inputs[0]: input_value 1784 }) 1785 outputs_bid_v = sess.run( 1786 outputs_bid, feed_dict={ 1787 inputs_using_dim[0]: input_value 1788 }) 1789 outputs_sav_v = sess.run( 1790 outputs_sav, feed_dict={ 1791 inputs_using_dim[0]: input_value 1792 }) 1793 1794 self.assertAllEqual(outputs_static_v, outputs_dynamic_v) 1795 self.assertAllEqual(outputs_static_v, outputs_sav_v) 1796 outputs_static_array = np.array(outputs_static_v) 1797 outputs_static_array_double = np.concatenate( 1798 (outputs_static_array, outputs_static_array), axis=2) 1799 outputs_bid_array = np.array(outputs_bid_v) 1800 self.assertAllEqual(outputs_static_array_double, outputs_bid_array) 1801 1802 state_static_v = sess.run( 1803 state_static, feed_dict={ 1804 inputs[0]: input_value 1805 }) 1806 state_dynamic_v = sess.run( 1807 state_dynamic, feed_dict={ 1808 inputs[0]: input_value 1809 }) 1810 state_bid_fw_v = sess.run( 1811 state_fw, feed_dict={ 1812 inputs_using_dim[0]: input_value 1813 }) 1814 state_bid_bw_v = sess.run( 1815 state_bw, feed_dict={ 1816 inputs_using_dim[0]: input_value 1817 }) 1818 state_sav_v = sess.run( 1819 state_sav, feed_dict={ 1820 inputs_using_dim[0]: input_value 1821 }) 1822 self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) 1823 self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v)) 1824 self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v)) 1825 self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_bw_v)) 1826 1827 1828class NestedLSTMTest(test.TestCase): 1829 1830 def setUp(self): 1831 self._seed = 23489 1832 np.random.seed(self._seed) 1833 1834 @test_util.run_v1_only("b/124229375") 1835 def testNestedIOLSTMAllRNNContainers(self): 1836 input_size = 5 1837 batch_size = 2 1838 state_size = 6 1839 max_length = 8 1840 sequence_length = [4, 6] 1841 with self.session(graph=ops.Graph()) as sess: 1842 state_saver = TestStateSaver(batch_size, 
state_size) 1843 single_input = (array_ops.placeholder( 1844 dtypes.float32, shape=(None, input_size)), 1845 array_ops.placeholder( 1846 dtypes.float32, shape=(None, input_size))) 1847 inputs = max_length * [single_input] 1848 inputs_c = (array_ops.stack([input_[0] for input_ in inputs]), 1849 array_ops.stack([input_[1] for input_ in inputs])) 1850 single_input_using_dim = (array_ops.placeholder( 1851 dtypes.float32, shape=(batch_size, input_size)), 1852 array_ops.placeholder( 1853 dtypes.float32, 1854 shape=(batch_size, input_size))) 1855 inputs_using_dim = max_length * [single_input_using_dim] 1856 1857 # Create a cell for the whole test. This is fine because the cell has no 1858 # variables. 1859 cell = NestedRNNCell() 1860 outputs_dynamic, state_dynamic = rnn.dynamic_rnn( 1861 cell, 1862 inputs_c, 1863 dtype=dtypes.float32, 1864 time_major=True, 1865 sequence_length=sequence_length) 1866 outputs_static, state_static = rnn.static_rnn( 1867 cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length) 1868 outputs_bid, state_fw, state_bw = rnn.static_bidirectional_rnn( 1869 cell, 1870 cell, 1871 inputs_using_dim, 1872 dtype=dtypes.float32, 1873 sequence_length=sequence_length) 1874 outputs_sav, state_sav = rnn.static_state_saving_rnn( 1875 cell, 1876 inputs_using_dim, 1877 sequence_length=sequence_length, 1878 state_saver=state_saver, 1879 state_name=("h", "c")) 1880 1881 def _assert_same_shape(input1, input2, double=False): 1882 flat_input1 = nest.flatten(input1) 1883 flat_input2 = nest.flatten(input2) 1884 for inp1, inp2 in zip(flat_input1, flat_input2): 1885 input_shape = inp1.get_shape().as_list() 1886 if double: 1887 input_shape[1] *= 2 1888 self.assertEqual(input_shape, inp2.get_shape().as_list()) 1889 1890 _assert_same_shape(inputs_c, outputs_dynamic) 1891 _assert_same_shape(inputs, outputs_static) 1892 _assert_same_shape(inputs_using_dim, outputs_sav) 1893 _assert_same_shape(inputs_using_dim, outputs_bid, double=True) 1894 1895 variables_lib.global_variables_initializer().run() 1896 1897 input_total_size = (batch_size, input_size) 1898 input_value = (np.random.randn(*input_total_size), 1899 np.random.randn(*input_total_size)) 1900 outputs_dynamic_v = sess.run( 1901 outputs_dynamic, feed_dict={ 1902 single_input: input_value 1903 }) 1904 outputs_static_v = sess.run( 1905 outputs_static, feed_dict={ 1906 single_input: input_value 1907 }) 1908 outputs_sav_v = sess.run( 1909 outputs_sav, feed_dict={ 1910 single_input_using_dim: input_value 1911 }) 1912 outputs_bid_v = sess.run( 1913 outputs_bid, feed_dict={ 1914 single_input_using_dim: input_value 1915 }) 1916 1917 self.assertAllEqual(outputs_static_v, 1918 np.transpose(outputs_dynamic_v, (1, 0, 2, 3))) 1919 self.assertAllEqual(outputs_static_v, outputs_sav_v) 1920 outputs_static_array = np.array(outputs_static_v) 1921 outputs_static_array_double = np.concatenate( 1922 (outputs_static_array, outputs_static_array), axis=3) 1923 outputs_bid_array = np.array(outputs_bid_v) 1924 self.assertAllEqual(outputs_static_array_double, outputs_bid_array) 1925 1926 state_dynamic_v = sess.run( 1927 state_dynamic, feed_dict={ 1928 single_input: input_value 1929 }) 1930 state_static_v = sess.run( 1931 state_static, feed_dict={ 1932 single_input: input_value 1933 }) 1934 state_bid_fw_v = sess.run( 1935 state_fw, feed_dict={ 1936 single_input_using_dim: input_value 1937 }) 1938 state_bid_bw_v = sess.run( 1939 state_bw, feed_dict={ 1940 single_input_using_dim: input_value 1941 }) 1942 state_sav_v = sess.run( 1943 state_sav, feed_dict={ 1944 
single_input_using_dim: input_value 1945 }) 1946 self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) 1947 self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v)) 1948 self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v)) 1949 self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_bw_v)) 1950 1951 1952class StateSaverRNNTest(test.TestCase): 1953 1954 def setUp(self): 1955 self._seed = 23489 1956 np.random.seed(self._seed) 1957 1958 def _factory(self, scope, state_saver): 1959 num_units = state_saver.state_size // 2 1960 batch_size = state_saver.batch_size 1961 input_size = 5 1962 max_length = 8 1963 initializer = init_ops.random_uniform_initializer( 1964 -0.01, 0.01, seed=self._seed) 1965 cell = rnn_cell.LSTMCell( 1966 num_units, 1967 use_peepholes=False, 1968 initializer=initializer, 1969 state_is_tuple=False) 1970 inputs = max_length * [ 1971 array_ops.zeros(dtype=dtypes.float32, shape=(batch_size, input_size)) 1972 ] 1973 out, state = rnn.static_state_saving_rnn( 1974 cell, 1975 inputs, 1976 state_saver=state_saver, 1977 state_name="save_lstm", 1978 scope=scope) 1979 return out, state, state_saver 1980 1981 def _testScope(self, prefix="prefix", use_outer_scope=True): 1982 num_units = 3 1983 batch_size = 2 1984 state_saver = TestStateSaver(batch_size, 2 * num_units) 1985 1986 with self.session(graph=ops.Graph()): 1987 if use_outer_scope: 1988 with variable_scope.variable_scope(prefix) as scope: 1989 self._factory(scope=scope, state_saver=state_saver) 1990 else: 1991 self._factory(scope=prefix, state_saver=state_saver) 1992 variables_lib.global_variables_initializer() 1993 1994 # Check that all the variable names start 1995 # with the proper scope. 1996 all_vars = variables_lib.global_variables() 1997 prefix = prefix or "rnn" 1998 scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] 1999 tf_logging.info("RNN with scope: %s (%s)" % 2000 (prefix, "scope" if use_outer_scope else "str")) 2001 for v in scope_vars: 2002 tf_logging.info(v.name) 2003 self.assertEqual(len(scope_vars), len(all_vars)) 2004 2005 def testStateSaverRNNScope(self): 2006 self._testScope(use_outer_scope=True) 2007 self._testScope(use_outer_scope=False) 2008 self._testScope(prefix=None, use_outer_scope=False) 2009 2010 def testStateSaverCallsSaveState(self): 2011 """Tests that state and save_state are called an equal number of times. 2012 2013 Tests that evaluating or skipping evaluation of the out and state tensors 2014 returned by static_state_saving_rnn, in any combination, does not change 2015 the number of calls to the state saver's state and save_state methods 2016 (the two call counts should always be equal.)
2017 """ 2018 self.skipTest("b/124196246 Breakage for sess.run([out, ...]): 2 != 1") 2019 2020 num_units = 3 2021 batch_size = 2 2022 state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units) 2023 out, state, state_saver = self._factory(scope=None, state_saver=state_saver) 2024 2025 with self.cached_session() as sess: 2026 sess.run(variables_lib.global_variables_initializer()) 2027 sess.run(variables_lib.local_variables_initializer()) 2028 2029 _, _, num_state_calls, num_save_state_calls = sess.run([ 2030 out, 2031 state, 2032 state_saver.num_state_calls, 2033 state_saver.num_save_state_calls]) 2034 self.assertEqual(num_state_calls, num_save_state_calls) 2035 2036 _, num_state_calls, num_save_state_calls = sess.run([ 2037 out, 2038 state_saver.num_state_calls, 2039 state_saver.num_save_state_calls]) 2040 self.assertEqual(num_state_calls, num_save_state_calls) 2041 2042 _, num_state_calls, num_save_state_calls = sess.run([ 2043 state, 2044 state_saver.num_state_calls, 2045 state_saver.num_save_state_calls]) 2046 self.assertEqual(num_state_calls, num_save_state_calls) 2047 2048class GRUTest(test.TestCase): 2049 2050 def setUp(self): 2051 self._seed = 23489 2052 np.random.seed(self._seed) 2053 2054 @test_util.run_v1_only("b/124229375") 2055 def testDynamic(self): 2056 time_steps = 8 2057 num_units = 3 2058 input_size = 5 2059 batch_size = 2 2060 2061 input_values = np.random.randn(time_steps, batch_size, input_size) 2062 2063 sequence_length = np.random.randint(0, time_steps, size=batch_size) 2064 2065 with self.session(graph=ops.Graph()) as sess: 2066 concat_inputs = array_ops.placeholder( 2067 dtypes.float32, shape=(time_steps, batch_size, input_size)) 2068 2069 cell = rnn_cell.GRUCell(num_units=num_units) 2070 2071 with variable_scope.variable_scope("dynamic_scope"): 2072 outputs_dynamic, state_dynamic = rnn.dynamic_rnn( 2073 cell, 2074 inputs=concat_inputs, 2075 sequence_length=sequence_length, 2076 time_major=True, 2077 dtype=dtypes.float32) 2078 2079 feeds = {concat_inputs: input_values} 2080 2081 # Initialize 2082 variables_lib.global_variables_initializer().run(feed_dict=feeds) 2083 2084 sess.run([outputs_dynamic, state_dynamic], feed_dict=feeds) 2085 2086 def _testScope(self, factory, prefix="prefix", use_outer_scope=True): 2087 with self.session(graph=ops.Graph()): 2088 if use_outer_scope: 2089 with variable_scope.variable_scope(prefix) as scope: 2090 factory(scope) 2091 else: 2092 factory(prefix) 2093 variables_lib.global_variables_initializer() 2094 2095 # check that all the variables names starts 2096 # with the proper scope. 
2097 all_vars = variables_lib.global_variables() 2098 prefix = prefix or "rnn" 2099 scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] 2100 tf_logging.info("RNN with scope: %s (%s)" % 2101 (prefix, "scope" if use_outer_scope else "str")) 2102 for v in scope_vars: 2103 tf_logging.info(v.name) 2104 self.assertEqual(len(scope_vars), len(all_vars)) 2105 2106 @test_util.run_v1_only("b/124229375") 2107 def testDynamicScope(self): 2108 time_steps = 8 2109 num_units = 3 2110 input_size = 5 2111 batch_size = 2 2112 sequence_length = np.random.randint(0, time_steps, size=batch_size) 2113 2114 def factory(scope): 2115 concat_inputs = array_ops.placeholder( 2116 dtypes.float32, shape=(time_steps, batch_size, input_size)) 2117 cell = rnn_cell.GRUCell(num_units=num_units) 2118 return rnn.dynamic_rnn( 2119 cell, 2120 inputs=concat_inputs, 2121 sequence_length=sequence_length, 2122 time_major=True, 2123 dtype=dtypes.float32, 2124 scope=scope) 2125 2126 self._testScope(factory, use_outer_scope=True) 2127 self._testScope(factory, use_outer_scope=False) 2128 self._testScope(factory, prefix=None, use_outer_scope=False) 2129 2130 2131class RawRNNTest(test.TestCase): 2132 2133 def setUp(self): 2134 self._seed = 23489 2135 np.random.seed(self._seed) 2136 2137 @test_util.run_v1_only("b/124229375") 2138 def _testRawRNN(self, max_time): 2139 with self.session(graph=ops.Graph()) as sess: 2140 batch_size = 16 2141 input_depth = 4 2142 num_units = 3 2143 2144 inputs = array_ops.placeholder( 2145 shape=(max_time, batch_size, input_depth), dtype=dtypes.float32) 2146 sequence_length = array_ops.placeholder( 2147 shape=(batch_size,), dtype=dtypes.int32) 2148 inputs_ta = tensor_array_ops.TensorArray( 2149 dtype=dtypes.float32, size=array_ops.shape(inputs)[0]) 2150 inputs_ta = inputs_ta.unstack(inputs) 2151 2152 cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True) 2153 2154 def loop_fn(time_, cell_output, cell_state, unused_loop_state): 2155 emit_output = cell_output # == None for time == 0 2156 if cell_output is None: # time == 0 2157 next_state = cell.zero_state(batch_size, dtypes.float32) 2158 else: 2159 next_state = cell_state # copy state through 2160 elements_finished = (time_ >= sequence_length) 2161 finished = math_ops.reduce_all(elements_finished) 2162 # For the very final iteration, we must emit a dummy input 2163 next_input = control_flow_ops.cond( 2164 finished, 2165 lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32), 2166 lambda: inputs_ta.read(time_)) 2167 return (elements_finished, next_input, next_state, emit_output, None) 2168 2169 reuse_scope = variable_scope.get_variable_scope() 2170 2171 outputs_ta, final_state, _ = rnn.raw_rnn(cell, loop_fn, scope=reuse_scope) 2172 outputs = outputs_ta.stack() 2173 2174 reuse_scope.reuse_variables() 2175 outputs_dynamic_rnn, final_state_dynamic_rnn = rnn.dynamic_rnn( 2176 cell, 2177 inputs, 2178 time_major=True, 2179 dtype=dtypes.float32, 2180 sequence_length=sequence_length, 2181 scope=reuse_scope) 2182 2183 variables = variables_lib.trainable_variables() 2184 gradients = gradients_impl.gradients([outputs, final_state], 2185 [inputs] + variables) 2186 gradients_dynamic_rnn = gradients_impl.gradients( 2187 [outputs_dynamic_rnn, final_state_dynamic_rnn], [inputs] + variables) 2188 2189 variables_lib.global_variables_initializer().run() 2190 2191 rand_input = np.random.randn(max_time, batch_size, input_depth) 2192 if max_time == 0: 2193 rand_seq_len = np.zeros(batch_size) 2194 else: 2195 rand_seq_len = 
np.random.randint(max_time, size=batch_size) 2196 2197 # To ensure same output lengths for dynamic_rnn and raw_rnn 2198 rand_seq_len[0] = max_time 2199 2200 (outputs_val, outputs_dynamic_rnn_val, final_state_val, 2201 final_state_dynamic_rnn_val) = sess.run( 2202 [outputs, outputs_dynamic_rnn, final_state, final_state_dynamic_rnn], 2203 feed_dict={ 2204 inputs: rand_input, 2205 sequence_length: rand_seq_len 2206 }) 2207 2208 self.assertAllClose(outputs_dynamic_rnn_val, outputs_val) 2209 self.assertAllClose(final_state_dynamic_rnn_val, final_state_val) 2210 2211 # NOTE: Because with 0 time steps, raw_rnn does not have shape 2212 # information about the input, it is impossible to perform 2213 # gradients comparisons as the gradients eval will fail. So 2214 # this case skips the gradients test. 2215 if max_time > 0: 2216 self.assertEqual(len(gradients), len(gradients_dynamic_rnn)) 2217 gradients_val = sess.run( 2218 gradients, 2219 feed_dict={ 2220 inputs: rand_input, 2221 sequence_length: rand_seq_len 2222 }) 2223 gradients_dynamic_rnn_val = sess.run( 2224 gradients_dynamic_rnn, 2225 feed_dict={ 2226 inputs: rand_input, 2227 sequence_length: rand_seq_len 2228 }) 2229 self.assertEqual(len(gradients_val), len(gradients_dynamic_rnn_val)) 2230 input_gradients_val = gradients_val[0] 2231 input_gradients_dynamic_rnn_val = gradients_dynamic_rnn_val[0] 2232 self.assertAllClose(input_gradients_val, 2233 input_gradients_dynamic_rnn_val) 2234 for i in range(1, len(gradients_val)): 2235 self.assertAllClose(gradients_dynamic_rnn_val[i], gradients_val[i]) 2236 2237 @test_util.run_v1_only("b/124229375") 2238 def testRawRNNZeroLength(self): 2239 # NOTE: Because with 0 time steps, raw_rnn does not have shape 2240 # information about the input, it is impossible to perform 2241 # gradients comparisons as the gradients eval will fail. So this 2242 # case skips the gradients test. 
2243 self._testRawRNN(max_time=0) 2244 2245 def testRawRNN(self): 2246 self._testRawRNN(max_time=10) 2247 2248 @test_util.run_v1_only("b/124229375") 2249 def testLoopState(self): 2250 with self.session(graph=ops.Graph()): 2251 max_time = 10 2252 batch_size = 16 2253 input_depth = 4 2254 num_units = 3 2255 2256 inputs = np.random.randn(max_time, batch_size, input_depth) 2257 inputs_ta = tensor_array_ops.TensorArray( 2258 dtype=dtypes.float32, size=array_ops.shape(inputs)[0]) 2259 inputs_ta = inputs_ta.unstack(inputs) 2260 2261 cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True) 2262 2263 def loop_fn(time_, cell_output, cell_state, loop_state): 2264 if cell_output is None: 2265 loop_state = constant_op.constant([0]) 2266 next_state = cell.zero_state(batch_size, dtypes.float32) 2267 else: 2268 loop_state = array_ops.stack([array_ops.squeeze(loop_state) + 1]) 2269 next_state = cell_state 2270 emit_output = cell_output # == None for time == 0 2271 elements_finished = array_ops.tile([time_ >= max_time], [batch_size]) 2272 finished = math_ops.reduce_all(elements_finished) 2273 # For the very final iteration, we must emit a dummy input 2274 next_input = control_flow_ops.cond( 2275 finished, 2276 lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32), 2277 lambda: inputs_ta.read(time_)) 2278 return (elements_finished, next_input, next_state, emit_output, 2279 loop_state) 2280 2281 r = rnn.raw_rnn(cell, loop_fn) 2282 loop_state = r[-1] 2283 self.assertEqual([10], self.evaluate(loop_state)) 2284 2285 @test_util.run_v1_only("b/124229375") 2286 def testLoopStateWithTensorArray(self): 2287 with self.session(graph=ops.Graph()): 2288 max_time = 4 2289 batch_size = 16 2290 input_depth = 4 2291 num_units = 3 2292 2293 inputs = np.random.randn(max_time, batch_size, input_depth) 2294 inputs_ta = tensor_array_ops.TensorArray( 2295 dtype=dtypes.float32, size=array_ops.shape(inputs)[0]) 2296 inputs_ta = inputs_ta.unstack(inputs) 2297 2298 cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True) 2299 2300 def loop_fn(time_, cell_output, cell_state, loop_state): 2301 if cell_output is None: 2302 loop_state = tensor_array_ops.TensorArray( 2303 dynamic_size=True, 2304 size=0, 2305 dtype=dtypes.int32, 2306 clear_after_read=False) 2307 loop_state = loop_state.write(0, 1) 2308 next_state = cell.zero_state(batch_size, dtypes.float32) 2309 else: 2310 loop_state = loop_state.write(time_, 2311 loop_state.read(time_ - 1) + time_) 2312 next_state = cell_state 2313 emit_output = cell_output # == None for time == 0 2314 elements_finished = array_ops.tile([time_ >= max_time], [batch_size]) 2315 finished = math_ops.reduce_all(elements_finished) 2316 # For the very final iteration, we must emit a dummy input 2317 next_input = control_flow_ops.cond( 2318 finished, 2319 lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32), 2320 lambda: inputs_ta.read(time_)) 2321 return (elements_finished, next_input, next_state, emit_output, 2322 loop_state) 2323 2324 r = rnn.raw_rnn(cell, loop_fn) 2325 loop_state = r[-1] 2326 loop_state = loop_state.stack() 2327 self.assertAllEqual([1, 2, 2 + 2, 4 + 3, 7 + 4], loop_state) 2328 2329 @test_util.run_v1_only("b/124229375") 2330 def testEmitDifferentStructureThanCellOutput(self): 2331 with self.session(graph=ops.Graph()) as sess: 2332 max_time = 10 2333 batch_size = 16 2334 input_depth = 4 2335 num_units = 3 2336 2337 inputs = np.random.randn(max_time, batch_size, input_depth) 2338 inputs_ta = tensor_array_ops.TensorArray( 2339 dtype=dtypes.float32, 
size=array_ops.shape(inputs)[0]) 2340 inputs_ta = inputs_ta.unstack(inputs) 2341 # Verify emit shapes may be unknown by feeding a placeholder that 2342 # determines an emit shape. 2343 unknown_dim = array_ops.placeholder(dtype=dtypes.int32) 2344 2345 cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True) 2346 2347 def loop_fn(time_, cell_output, cell_state, _): 2348 if cell_output is None: 2349 emit_output = (array_ops.zeros([2, 3], dtype=dtypes.int32), 2350 array_ops.zeros([unknown_dim], dtype=dtypes.int64)) 2351 next_state = cell.zero_state(batch_size, dtypes.float32) 2352 else: 2353 emit_output = (array_ops.ones([batch_size, 2, 3], dtype=dtypes.int32), 2354 array_ops.ones( 2355 [batch_size, unknown_dim], dtype=dtypes.int64)) 2356 next_state = cell_state 2357 elements_finished = array_ops.tile([time_ >= max_time], [batch_size]) 2358 finished = math_ops.reduce_all(elements_finished) 2359 # For the very final iteration, we must emit a dummy input 2360 next_input = control_flow_ops.cond( 2361 finished, 2362 lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32), 2363 lambda: inputs_ta.read(time_)) 2364 return (elements_finished, next_input, next_state, emit_output, None) 2365 2366 r = rnn.raw_rnn(cell, loop_fn) 2367 output_ta = r[0] 2368 self.assertEqual(2, len(output_ta)) 2369 self.assertEqual([dtypes.int32, dtypes.int64], 2370 [ta.dtype for ta in output_ta]) 2371 output = [ta.stack() for ta in output_ta] 2372 output_vals = sess.run(output, feed_dict={unknown_dim: 1}) 2373 self.assertAllEqual( 2374 np.ones((max_time, batch_size, 2, 3), np.int32), output_vals[0]) 2375 self.assertAllEqual( 2376 np.ones((max_time, batch_size, 1), np.int64), output_vals[1]) 2377 2378 def _testScope(self, factory, prefix="prefix", use_outer_scope=True): 2379 with self.session(graph=ops.Graph()): 2380 if use_outer_scope: 2381 with variable_scope.variable_scope(prefix) as scope: 2382 factory(scope) 2383 else: 2384 factory(prefix) 2385 variables_lib.global_variables_initializer() 2386 2387 # check that all the variables names starts 2388 # with the proper scope. 
2389 all_vars = variables_lib.global_variables() 2390 prefix = prefix or "rnn" 2391 scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] 2392 tf_logging.info("RNN with scope: %s (%s)" % 2393 (prefix, "scope" if use_outer_scope else "str")) 2394 for v in scope_vars: 2395 tf_logging.info(v.name) 2396 self.assertEqual(len(scope_vars), len(all_vars)) 2397 2398 @test_util.run_v1_only("b/124229375") 2399 def testRawRNNScope(self): 2400 max_time = 10 2401 batch_size = 16 2402 input_depth = 4 2403 num_units = 3 2404 2405 def factory(scope): 2406 inputs = array_ops.placeholder( 2407 shape=(max_time, batch_size, input_depth), dtype=dtypes.float32) 2408 sequence_length = array_ops.placeholder( 2409 shape=(batch_size,), dtype=dtypes.int32) 2410 inputs_ta = tensor_array_ops.TensorArray( 2411 dtype=dtypes.float32, size=array_ops.shape(inputs)[0]) 2412 inputs_ta = inputs_ta.unstack(inputs) 2413 2414 cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True) 2415 2416 def loop_fn(time_, cell_output, cell_state, unused_loop_state): 2417 emit_output = cell_output # == None for time == 0 2418 if cell_output is None: # time == 0 2419 next_state = cell.zero_state(batch_size, dtypes.float32) 2420 else: 2421 next_state = cell_state 2422 2423 elements_finished = (time_ >= sequence_length) 2424 finished = math_ops.reduce_all(elements_finished) 2425 # For the very final iteration, we must emit a dummy input 2426 next_input = control_flow_ops.cond( 2427 finished, 2428 lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32), 2429 lambda: inputs_ta.read(time_)) 2430 return (elements_finished, next_input, next_state, emit_output, None) 2431 2432 return rnn.raw_rnn(cell, loop_fn, scope=scope) 2433 2434 self._testScope(factory, use_outer_scope=True) 2435 self._testScope(factory, use_outer_scope=False) 2436 self._testScope(factory, prefix=None, use_outer_scope=False) 2437 2438 2439class DeviceWrapperCell(rnn_cell.RNNCell): 2440 """Class to ensure cell calculation happens on a specific device.""" 2441 2442 def __init__(self, cell, device): 2443 self._cell = cell 2444 self._device = device 2445 2446 @property 2447 def output_size(self): 2448 return self._cell.output_size 2449 2450 @property 2451 def state_size(self): 2452 return self._cell.state_size 2453 2454 def __call__(self, input_, state, scope=None): 2455 if self._device is not None: 2456 with ops.device(self._device): 2457 return self._cell(input_, state, scope=scope) 2458 else: 2459 return self._cell(input_, state, scope=scope) 2460 2461 2462class TensorArrayOnCorrectDeviceTest(test.TestCase): 2463 2464 def _execute_rnn_on(self, 2465 rnn_device=None, 2466 cell_device=None, 2467 input_device=None): 2468 batch_size = 3 2469 time_steps = 7 2470 input_size = 5 2471 num_units = 10 2472 2473 cell = rnn_cell.LSTMCell(num_units, use_peepholes=True) 2474 gpu_cell = DeviceWrapperCell(cell, cell_device) 2475 inputs = np.random.randn(batch_size, time_steps, input_size).astype( 2476 np.float32) 2477 sequence_length = np.random.randint(0, time_steps, size=batch_size) 2478 2479 if input_device is not None: 2480 with ops.device(input_device): 2481 inputs = constant_op.constant(inputs) 2482 2483 if rnn_device is not None: 2484 with ops.device(rnn_device): 2485 outputs, _ = rnn.dynamic_rnn( 2486 gpu_cell, 2487 inputs, 2488 sequence_length=sequence_length, 2489 dtype=dtypes.float32) 2490 else: 2491 outputs, _ = rnn.dynamic_rnn( 2492 gpu_cell, 2493 inputs, 2494 sequence_length=sequence_length, 2495 dtype=dtypes.float32) 2496 2497 with self.session() as 
sess: 2498 opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) 2499 run_metadata = config_pb2.RunMetadata() 2500 variables_lib.global_variables_initializer().run() 2501 sess.run(outputs, options=opts, run_metadata=run_metadata) 2502 2503 return run_metadata 2504 2505 def _retrieve_cpu_gpu_stats(self, run_metadata): 2506 cpu_stats = None 2507 gpu_stats = None 2508 step_stats = run_metadata.step_stats 2509 for ds in step_stats.dev_stats: 2510 if "cpu:0" in ds.device[-5:].lower(): 2511 cpu_stats = ds.node_stats 2512 if "gpu:0" == ds.device[-5:].lower(): 2513 gpu_stats = ds.node_stats 2514 return cpu_stats, gpu_stats 2515 2516 @test_util.run_v1_only("b/124229375") 2517 def testRNNOnCPUCellOnGPU(self): 2518 if not test.is_gpu_available(): 2519 return # Test requires access to a GPU 2520 2521 gpu_dev = test.gpu_device_name() 2522 run_metadata = self._execute_rnn_on( 2523 rnn_device="/cpu:0", cell_device=gpu_dev) 2524 cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) 2525 2526 def _assert_in(op_str, in_stats, out_stats): 2527 self.assertTrue(any(op_str in s.node_name for s in in_stats)) 2528 self.assertFalse(any(op_str in s.node_name for s in out_stats)) 2529 2530 # Writes happen at output of RNN cell 2531 _assert_in("TensorArrayWrite", gpu_stats, cpu_stats) 2532 # Gather happens on final TensorArray 2533 _assert_in("TensorArrayGather", gpu_stats, cpu_stats) 2534 # Reads happen at input to RNN cell 2535 _assert_in("TensorArrayRead", cpu_stats, gpu_stats) 2536 # Scatters happen to get initial input into TensorArray 2537 _assert_in("TensorArrayScatter", cpu_stats, gpu_stats) 2538 2539 @test_util.run_v1_only("b/124229375") 2540 def testRNNOnCPUCellOnCPU(self): 2541 if not test.is_gpu_available(): 2542 return # Test requires access to a GPU 2543 2544 gpu_dev = test.gpu_device_name() 2545 run_metadata = self._execute_rnn_on( 2546 rnn_device="/cpu:0", cell_device="/cpu:0", input_device=gpu_dev) 2547 cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) 2548 2549 def _assert_in(op_str, in_stats, out_stats): 2550 self.assertTrue(any(op_str in s.node_name for s in in_stats)) 2551 self.assertFalse(any(op_str in s.node_name for s in out_stats)) 2552 2553 # All TensorArray operations happen on CPU 2554 _assert_in("TensorArray", cpu_stats, gpu_stats) 2555 2556 @test_util.run_v1_only("b/124229375") 2557 def testInputOnGPUCellNotDeclared(self): 2558 if not test.is_gpu_available(): 2559 return # Test requires access to a GPU 2560 2561 gpu_dev = test.gpu_device_name() 2562 run_metadata = self._execute_rnn_on(input_device=gpu_dev) 2563 cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) 2564 2565 def _assert_in(op_str, in_stats, out_stats): 2566 self.assertTrue(any(op_str in s.node_name for s in in_stats)) 2567 self.assertFalse(any(op_str in s.node_name for s in out_stats)) 2568 2569 # Everything happens on GPU 2570 _assert_in("TensorArray", gpu_stats, cpu_stats) 2571 2572 2573class RNNCellTest(test.TestCase, parameterized.TestCase): 2574 2575 @test_util.run_v1_only("b/124229375") 2576 def testBasicRNNCell(self): 2577 with self.cached_session() as sess: 2578 with variable_scope.variable_scope( 2579 "root", initializer=init_ops.constant_initializer(0.5)): 2580 x = array_ops.zeros([1, 2]) 2581 m = array_ops.zeros([1, 2]) 2582 cell = rnn_cell_impl.BasicRNNCell(2) 2583 g, _ = cell(x, m) 2584 self.assertEqual([ 2585 "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 2586 "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME 2587 
], [v.name for v in cell.trainable_variables]) 2588 self.assertFalse(cell.non_trainable_variables) 2589 sess.run([variables_lib.global_variables_initializer()]) 2590 res = sess.run([g], { 2591 x: np.array([[1., 1.]]), 2592 m: np.array([[0.1, 0.1]]) 2593 }) 2594 self.assertEqual(res[0].shape, (1, 2)) 2595 2596 @test_util.run_v1_only("b/124229375") 2597 def testBasicRNNCellNotTrainable(self): 2598 with self.cached_session() as sess: 2599 2600 def not_trainable_getter(getter, *args, **kwargs): 2601 kwargs["trainable"] = False 2602 return getter(*args, **kwargs) 2603 2604 with variable_scope.variable_scope( 2605 "root", 2606 initializer=init_ops.constant_initializer(0.5), 2607 custom_getter=not_trainable_getter): 2608 x = array_ops.zeros([1, 2]) 2609 m = array_ops.zeros([1, 2]) 2610 cell = rnn_cell_impl.BasicRNNCell(2) 2611 g, _ = cell(x, m) 2612 self.assertFalse(cell.trainable_variables) 2613 self.assertEqual([ 2614 "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 2615 "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME 2616 ], [v.name for v in cell.non_trainable_variables]) 2617 sess.run([variables_lib.global_variables_initializer()]) 2618 res = sess.run([g], { 2619 x: np.array([[1., 1.]]), 2620 m: np.array([[0.1, 0.1]]) 2621 }) 2622 self.assertEqual(res[0].shape, (1, 2)) 2623 2624 @test_util.run_v1_only("b/124229375") 2625 def testGRUCell(self): 2626 with self.cached_session() as sess: 2627 with variable_scope.variable_scope( 2628 "root", initializer=init_ops.constant_initializer(0.5)): 2629 x = array_ops.zeros([1, 2]) 2630 m = array_ops.zeros([1, 2]) 2631 g, _ = rnn_cell_impl.GRUCell(2)(x, m) 2632 sess.run([variables_lib.global_variables_initializer()]) 2633 res = sess.run([g], { 2634 x: np.array([[1., 1.]]), 2635 m: np.array([[0.1, 0.1]]) 2636 }) 2637 # Smoke test 2638 self.assertAllClose(res[0], [[0.175991, 0.175991]]) 2639 with variable_scope.variable_scope( 2640 "other", initializer=init_ops.constant_initializer(0.5)): 2641 # Test GRUCell with input_size != num_units. 2642 x = array_ops.zeros([1, 3]) 2643 m = array_ops.zeros([1, 2]) 2644 g, _ = rnn_cell_impl.GRUCell(2)(x, m) 2645 sess.run([variables_lib.global_variables_initializer()]) 2646 res = sess.run([g], { 2647 x: np.array([[1., 1., 1.]]), 2648 m: np.array([[0.1, 0.1]]) 2649 }) 2650 # Smoke test 2651 self.assertAllClose(res[0], [[0.156736, 0.156736]]) 2652 2653 @test_util.run_v1_only("b/124229375") 2654 def testBasicLSTMCell(self): 2655 for dtype in [dtypes.float16, dtypes.float32]: 2656 np_dtype = dtype.as_numpy_dtype 2657 with self.session(graph=ops.Graph()) as sess: 2658 with variable_scope.variable_scope( 2659 "root", initializer=init_ops.constant_initializer(0.5)): 2660 x = array_ops.zeros([1, 2], dtype=dtype) 2661 m = array_ops.zeros([1, 8], dtype=dtype) 2662 cell = rnn_cell_impl.MultiRNNCell( 2663 [ 2664 rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) 2665 for _ in range(2) 2666 ], 2667 state_is_tuple=False) 2668 self.assertEqual(cell.dtype, None) 2669 self.assertIn("cell-0", cell._trackable_children()) 2670 self.assertIn("cell-1", cell._trackable_children()) 2671 cell.get_config() # Should not throw an error 2672 g, out_m = cell(x, m) 2673 # Layer infers the input type. 
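# (cell.dtype was None before the first call, as asserted above; calling the
# cell sets it from the dtype of its inputs, which the assertion below checks.)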
2674 self.assertEqual(cell.dtype, dtype.name) 2675 expected_variable_names = [ 2676 "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % 2677 rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 2678 "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % 2679 rnn_cell_impl._BIAS_VARIABLE_NAME, 2680 "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % 2681 rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 2682 "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % 2683 rnn_cell_impl._BIAS_VARIABLE_NAME 2684 ] 2685 self.assertEqual(expected_variable_names, 2686 [v.name for v in cell.trainable_variables]) 2687 self.assertFalse(cell.non_trainable_variables) 2688 sess.run([variables_lib.global_variables_initializer()]) 2689 res = sess.run([g, out_m], { 2690 x: np.array([[1., 1.]]), 2691 m: 0.1 * np.ones([1, 8]) 2692 }) 2693 self.assertEqual(len(res), 2) 2694 variables = variables_lib.global_variables() 2695 self.assertEqual(expected_variable_names, [v.name for v in variables]) 2696 # The numbers in results were not calculated, this is just a 2697 # smoke test. 2698 self.assertAllClose(res[0], np.array( 2699 [[0.240, 0.240]], dtype=np_dtype), 1e-2) 2700 expected_mem = np.array( 2701 [[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]], 2702 dtype=np_dtype) 2703 self.assertAllClose(res[1], expected_mem, 1e-2) 2704 with variable_scope.variable_scope( 2705 "other", initializer=init_ops.constant_initializer(0.5)): 2706 # Test BasicLSTMCell with input_size != num_units. 2707 x = array_ops.zeros([1, 3], dtype=dtype) 2708 m = array_ops.zeros([1, 4], dtype=dtype) 2709 g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m) 2710 sess.run([variables_lib.global_variables_initializer()]) 2711 res = sess.run( 2712 [g, out_m], { 2713 x: np.array([[1., 1., 1.]], dtype=np_dtype), 2714 m: 0.1 * np.ones([1, 4], dtype=np_dtype) 2715 }) 2716 self.assertEqual(len(res), 2) 2717 2718 @test_util.run_v1_only("b/124229375") 2719 def testBasicLSTMCellDimension0Error(self): 2720 """Tests that dimension 0 in both(x and m) shape must be equal.""" 2721 with self.cached_session() as sess: 2722 with variable_scope.variable_scope( 2723 "root", initializer=init_ops.constant_initializer(0.5)): 2724 num_units = 2 2725 state_size = num_units * 2 2726 batch_size = 3 2727 input_size = 4 2728 x = array_ops.zeros([batch_size, input_size]) 2729 m = array_ops.zeros([batch_size - 1, state_size]) 2730 with self.assertRaises(ValueError): 2731 g, out_m = rnn_cell_impl.BasicLSTMCell( 2732 num_units, state_is_tuple=False)(x, m) 2733 sess.run([variables_lib.global_variables_initializer()]) 2734 sess.run( 2735 [g, out_m], { 2736 x: 1 * np.ones([batch_size, input_size]), 2737 m: 0.1 * np.ones([batch_size - 1, state_size]) 2738 }) 2739 2740 def testBasicLSTMCellStateSizeError(self): 2741 """Tests that state_size must be num_units * 2.""" 2742 with self.cached_session() as sess: 2743 with variable_scope.variable_scope( 2744 "root", initializer=init_ops.constant_initializer(0.5)): 2745 num_units = 2 2746 state_size = num_units * 3 # state_size must be num_units * 2 2747 batch_size = 3 2748 input_size = 4 2749 x = array_ops.zeros([batch_size, input_size]) 2750 m = array_ops.zeros([batch_size, state_size]) 2751 with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): 2752 g, out_m = rnn_cell_impl.BasicLSTMCell( 2753 num_units, state_is_tuple=False)(x, m) 2754 sess.run([variables_lib.global_variables_initializer()]) 2755 sess.run( 2756 [g, out_m], { 2757 x: 1 * np.ones([batch_size, input_size]), 2758 m: 0.1 * np.ones([batch_size, state_size]) 2759 }) 
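  # Note (illustrative, not executed by these tests): with state_is_tuple=False
  # a BasicLSTMCell packs its state into a single [batch, 2 * num_units]
  # tensor, roughly
  #   c, h = array_ops.split(flat_state, num_or_size_splits=2, axis=1)
  # which is why the non-tuple tests above size their state placeholders as
  # num_units * 2. The tests below exercise the equivalent LSTMStateTuple form.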
2760 2761 @test_util.run_v1_only("b/124229375") 2762 def testBasicLSTMCellStateTupleType(self): 2763 with self.cached_session(): 2764 with variable_scope.variable_scope( 2765 "root", initializer=init_ops.constant_initializer(0.5)): 2766 x = array_ops.zeros([1, 2]) 2767 m0 = (array_ops.zeros([1, 2]),) * 2 2768 m1 = (array_ops.zeros([1, 2]),) * 2 2769 cell = rnn_cell_impl.MultiRNNCell( 2770 [rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)], 2771 state_is_tuple=True) 2772 self.assertTrue(isinstance(cell.state_size, tuple)) 2773 self.assertTrue( 2774 isinstance(cell.state_size[0], rnn_cell_impl.LSTMStateTuple)) 2775 self.assertTrue( 2776 isinstance(cell.state_size[1], rnn_cell_impl.LSTMStateTuple)) 2777 2778 # Pass in regular tuples 2779 _, (out_m0, out_m1) = cell(x, (m0, m1)) 2780 self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple)) 2781 self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple)) 2782 2783 # Pass in LSTMStateTuples 2784 variable_scope.get_variable_scope().reuse_variables() 2785 zero_state = cell.zero_state(1, dtypes.float32) 2786 self.assertTrue(isinstance(zero_state, tuple)) 2787 self.assertTrue(isinstance(zero_state[0], rnn_cell_impl.LSTMStateTuple)) 2788 self.assertTrue(isinstance(zero_state[1], rnn_cell_impl.LSTMStateTuple)) 2789 _, (out_m0, out_m1) = cell(x, zero_state) 2790 self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple)) 2791 self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple)) 2792 2793 @test_util.run_v1_only("b/124229375") 2794 def testBasicLSTMCellWithStateTuple(self): 2795 with self.cached_session() as sess: 2796 with variable_scope.variable_scope( 2797 "root", initializer=init_ops.constant_initializer(0.5)): 2798 x = array_ops.zeros([1, 2]) 2799 m0 = array_ops.zeros([1, 4]) 2800 m1 = array_ops.zeros([1, 4]) 2801 cell = rnn_cell_impl.MultiRNNCell( 2802 [ 2803 rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) 2804 for _ in range(2) 2805 ], 2806 state_is_tuple=True) 2807 g, (out_m0, out_m1) = cell(x, (m0, m1)) 2808 sess.run([variables_lib.global_variables_initializer()]) 2809 res = sess.run( 2810 [g, out_m0, out_m1], { 2811 x: np.array([[1., 1.]]), 2812 m0: 0.1 * np.ones([1, 4]), 2813 m1: 0.1 * np.ones([1, 4]) 2814 }) 2815 self.assertEqual(len(res), 3) 2816 # The numbers in results were not calculated, this is just a smoke test. 2817 # Note, however, these values should match the original 2818 # version having state_is_tuple=False. 
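# (Concretely, expected_mem0 and expected_mem1 below are the first and second
# halves of the flat expected_mem golden value used in testBasicLSTMCell above.)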
2819 self.assertAllClose(res[0], [[0.24024698, 0.24024698]]) 2820 expected_mem0 = np.array( 2821 [[0.68967271, 0.68967271, 0.44848421, 0.44848421]]) 2822 expected_mem1 = np.array( 2823 [[0.39897051, 0.39897051, 0.24024698, 0.24024698]]) 2824 self.assertAllClose(res[1], expected_mem0) 2825 self.assertAllClose(res[2], expected_mem1) 2826 2827 @test_util.run_v1_only("b/124229375") 2828 def testLSTMCell(self): 2829 with self.cached_session() as sess: 2830 num_units = 8 2831 num_proj = 6 2832 state_size = num_units + num_proj 2833 batch_size = 3 2834 input_size = 2 2835 with variable_scope.variable_scope( 2836 "root", initializer=init_ops.constant_initializer(0.5)): 2837 x = array_ops.zeros([batch_size, input_size]) 2838 m = array_ops.zeros([batch_size, state_size]) 2839 cell = rnn_cell_impl.LSTMCell( 2840 num_units=num_units, 2841 num_proj=num_proj, 2842 forget_bias=1.0, 2843 state_is_tuple=False) 2844 output, state = cell(x, m) 2845 sess.run([variables_lib.global_variables_initializer()]) 2846 res = sess.run( 2847 [output, state], { 2848 x: np.array([[1., 1.], [2., 2.], [3., 3.]]), 2849 m: 0.1 * np.ones((batch_size, state_size)) 2850 }) 2851 self.assertEqual(len(res), 2) 2852 # The numbers in results were not calculated, this is mostly just a 2853 # smoke test. 2854 self.assertEqual(res[0].shape, (batch_size, num_proj)) 2855 self.assertEqual(res[1].shape, (batch_size, state_size)) 2856 # Different inputs so different outputs and states 2857 for i in range(1, batch_size): 2858 self.assertTrue( 2859 float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6) 2860 self.assertTrue( 2861 float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6) 2862 2863 @test_util.run_v1_only("b/124229375") 2864 def testLSTMCellVariables(self): 2865 with self.cached_session(): 2866 num_units = 8 2867 num_proj = 6 2868 state_size = num_units + num_proj 2869 batch_size = 3 2870 input_size = 2 2871 with variable_scope.variable_scope( 2872 "root", initializer=init_ops.constant_initializer(0.5)): 2873 x = array_ops.zeros([batch_size, input_size]) 2874 m = array_ops.zeros([batch_size, state_size]) 2875 cell = rnn_cell_impl.LSTMCell( 2876 num_units=num_units, 2877 num_proj=num_proj, 2878 forget_bias=1.0, 2879 state_is_tuple=False) 2880 cell(x, m) # Execute to create variables 2881 variables = variables_lib.global_variables() 2882 self.assertEqual(variables[0].op.name, "root/lstm_cell/kernel") 2883 self.assertEqual(variables[1].op.name, "root/lstm_cell/bias") 2884 self.assertEqual(variables[2].op.name, "root/lstm_cell/projection/kernel") 2885 2886 @test_util.run_in_graph_and_eager_modes 2887 def testWrapperCheckpointing(self): 2888 for wrapper_type in [ 2889 rnn_cell_impl.DropoutWrapper, 2890 rnn_cell_impl.ResidualWrapper, 2891 lambda cell: rnn_cell_impl.MultiRNNCell([cell])]: 2892 cell = rnn_cell_impl.BasicRNNCell(1) 2893 wrapper = wrapper_type(cell) 2894 wrapper(array_ops.ones([1, 1]), 2895 state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) 2896 self.evaluate([v.initializer for v in cell.variables]) 2897 checkpoint = trackable_utils.Checkpoint(wrapper=wrapper) 2898 prefix = os.path.join(self.get_temp_dir(), "ckpt") 2899 self.evaluate(cell._bias.assign([40.])) 2900 save_path = checkpoint.save(prefix) 2901 self.evaluate(cell._bias.assign([0.])) 2902 checkpoint.restore(save_path).assert_consumed().run_restore_ops() 2903 self.assertAllEqual([40.], self.evaluate(cell._bias)) 2904 2905 @test_util.run_in_graph_and_eager_modes 2906 def testResidualWrapper(self): 2907 wrapper_type = rnn_cell_impl.ResidualWrapper 
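    # ResidualWrapper adds the wrapped cell's input to its output (or combines
    # them with a user-supplied residual_fn, exercised in
    # testResidualWrapperWithSlice below) while passing the state through
    # unchanged; the assertions at the end of this test check exactly that.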
2908 x = ops.convert_to_tensor(np.array([[1., 1., 1.]])) 2909 m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]])) 2910 base_cell = rnn_cell_impl.GRUCell( 2911 3, kernel_initializer=init_ops.constant_initializer(0.5), 2912 bias_initializer=init_ops.constant_initializer(0.5)) 2913 g, m_new = base_cell(x, m) 2914 wrapper_object = wrapper_type(base_cell) 2915 wrapper_object.get_config() # Should not throw an error 2916 2917 self.assertIn("cell", wrapper_object._trackable_children()) 2918 self.assertIs(wrapper_object._trackable_children()["cell"], base_cell) 2919 2920 g_res, m_new_res = wrapper_object(x, m) 2921 self.evaluate([variables_lib.global_variables_initializer()]) 2922 res = self.evaluate([g, g_res, m_new, m_new_res]) 2923 # Residual connections 2924 self.assertAllClose(res[1], res[0] + [1., 1., 1.]) 2925 # States are left untouched 2926 self.assertAllClose(res[2], res[3]) 2927 2928 @test_util.run_in_graph_and_eager_modes 2929 def testResidualWrapperWithSlice(self): 2930 wrapper_type = rnn_cell_impl.ResidualWrapper 2931 x = ops.convert_to_tensor(np.array([[1., 1., 1., 1., 1.]])) 2932 m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]])) 2933 base_cell = rnn_cell_impl.GRUCell( 2934 3, kernel_initializer=init_ops.constant_initializer(0.5), 2935 bias_initializer=init_ops.constant_initializer(0.5)) 2936 g, m_new = base_cell(x, m) 2937 2938 def residual_with_slice_fn(inp, out): 2939 inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3]) 2940 return inp_sliced + out 2941 2942 g_res, m_new_res = wrapper_type( 2943 base_cell, residual_with_slice_fn)(x, m) 2944 self.evaluate([variables_lib.global_variables_initializer()]) 2945 res_g, res_g_res, res_m_new, res_m_new_res = self.evaluate( 2946 [g, g_res, m_new, m_new_res]) 2947 # Residual connections 2948 self.assertAllClose(res_g_res, res_g + [1., 1., 1.]) 2949 # States are left untouched 2950 self.assertAllClose(res_m_new, res_m_new_res) 2951 2952 def testDeviceWrapper(self): 2953 wrapper_type = rnn_cell_impl.DeviceWrapper 2954 x = array_ops.zeros([1, 3]) 2955 m = array_ops.zeros([1, 3]) 2956 cell = rnn_cell_impl.GRUCell(3) 2957 wrapped_cell = wrapper_type(cell, "/cpu:0") 2958 wrapped_cell.get_config() # Should not throw an error 2959 self.assertEqual(wrapped_cell._trackable_children()["cell"], cell) 2960 2961 outputs, _ = wrapped_cell(x, m) 2962 self.assertIn("cpu:0", outputs.device.lower()) 2963 2964 def _retrieve_cpu_gpu_stats(self, run_metadata): 2965 cpu_stats = None 2966 gpu_stats = None 2967 step_stats = run_metadata.step_stats 2968 for ds in step_stats.dev_stats: 2969 if "cpu:0" in ds.device[-5:].lower(): 2970 cpu_stats = ds.node_stats 2971 if "gpu:0" == ds.device[-5:].lower(): 2972 gpu_stats = ds.node_stats 2973 return cpu_stats, gpu_stats 2974 2975 @test_util.run_v1_only("b/124229375") 2976 def testDeviceWrapperDynamicExecutionNodesAreAllProperlyLocated(self): 2977 if not test.is_gpu_available(): 2978 # Can't perform this test w/o a GPU 2979 return 2980 2981 gpu_dev = test.gpu_device_name() 2982 with self.session() as sess: 2983 with variable_scope.variable_scope( 2984 "root", initializer=init_ops.constant_initializer(0.5)): 2985 x = array_ops.zeros([1, 1, 3]) 2986 cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), gpu_dev) 2987 with ops.device("/cpu:0"): 2988 outputs, _ = rnn.dynamic_rnn( 2989 cell=cell, inputs=x, dtype=dtypes.float32) 2990 run_metadata = config_pb2.RunMetadata() 2991 opts = config_pb2.RunOptions( 2992 trace_level=config_pb2.RunOptions.FULL_TRACE) 2993 2994 
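        # Tracing with FULL_TRACE fills run_metadata.step_stats with per-device
        # node stats; _retrieve_cpu_gpu_stats below uses them to verify on
        # which device the gru_cell ops actually ran.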
sess.run([variables_lib.global_variables_initializer()]) 2995 _ = sess.run(outputs, options=opts, run_metadata=run_metadata) 2996 2997 cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) 2998 self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name]) 2999 self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name]) 3000 3001 @test_util.run_v1_only("b/124229375") 3002 def testMultiRNNCell(self): 3003 with self.cached_session() as sess: 3004 with variable_scope.variable_scope( 3005 "root", initializer=init_ops.constant_initializer(0.5)): 3006 x = array_ops.zeros([1, 2]) 3007 m = array_ops.zeros([1, 4]) 3008 multi_rnn_cell = rnn_cell_impl.MultiRNNCell( 3009 [rnn_cell_impl.GRUCell(2) for _ in range(2)], 3010 state_is_tuple=False) 3011 _, ml = multi_rnn_cell(x, m) 3012 sess.run([variables_lib.global_variables_initializer()]) 3013 res = sess.run(ml, { 3014 x: np.array([[1., 1.]]), 3015 m: np.array([[0.1, 0.1, 0.1, 0.1]]) 3016 }) 3017 # The numbers in results were not calculated, this is just a smoke test. 3018 self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]]) 3019 self.assertEqual(len(multi_rnn_cell.weights), 2 * 4) 3020 self.assertTrue( 3021 [x.dtype == dtypes.float32 for x in multi_rnn_cell.weights]) 3022 3023 @test_util.run_v1_only("b/124229375") 3024 def testMultiRNNCellWithStateTuple(self): 3025 with self.cached_session() as sess: 3026 with variable_scope.variable_scope( 3027 "root", initializer=init_ops.constant_initializer(0.5)): 3028 x = array_ops.zeros([1, 2]) 3029 m_bad = array_ops.zeros([1, 4]) 3030 m_good = (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])) 3031 3032 # Test incorrectness of state 3033 with self.assertRaisesRegex(ValueError, "Expected state .* a tuple"): 3034 rnn_cell_impl.MultiRNNCell( 3035 [rnn_cell_impl.GRUCell(2) for _ in range(2)], 3036 state_is_tuple=True)(x, m_bad) 3037 3038 _, ml = rnn_cell_impl.MultiRNNCell( 3039 [rnn_cell_impl.GRUCell(2) for _ in range(2)], 3040 state_is_tuple=True)(x, m_good) 3041 3042 sess.run([variables_lib.global_variables_initializer()]) 3043 res = sess.run( 3044 ml, { 3045 x: np.array([[1., 1.]]), 3046 m_good[0]: np.array([[0.1, 0.1]]), 3047 m_good[1]: np.array([[0.1, 0.1]]) 3048 }) 3049 3050 # The numbers in results were not calculated, this is just a 3051 # smoke test. However, these numbers should match those of 3052 # the test testMultiRNNCell. 3053 self.assertAllClose(res[0], [[0.175991, 0.175991]]) 3054 self.assertAllClose(res[1], [[0.13248, 0.13248]]) 3055 3056 def testDeviceWrapperSerialization(self): 3057 wrapper_cls = rnn_cell_impl.DeviceWrapper 3058 cell = rnn_cell_impl.LSTMCell(10) 3059 wrapper = wrapper_cls(cell, "/cpu:0") 3060 config = wrapper.get_config() 3061 3062 # Replace the cell in the config with real cell instance to work around the 3063 # reverse keras dependency issue. 3064 config_copy = config.copy() 3065 config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( 3066 config_copy["cell"]["config"]) 3067 reconstructed_wrapper = wrapper_cls.from_config(config_copy) 3068 self.assertDictEqual(config, reconstructed_wrapper.get_config()) 3069 self.assertIsInstance(reconstructed_wrapper, wrapper_cls) 3070 3071 def testResidualWrapperSerialization(self): 3072 wrapper_cls = rnn_cell_impl.ResidualWrapper 3073 cell = rnn_cell_impl.LSTMCell(10) 3074 wrapper = wrapper_cls(cell) 3075 config = wrapper.get_config() 3076 3077 # Replace the cell in the config with real cell instance to work around the 3078 # reverse keras dependency issue. 
3079 config_copy = config.copy() 3080 config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( 3081 config_copy["cell"]["config"]) 3082 reconstructed_wrapper = wrapper_cls.from_config(config_copy) 3083 self.assertDictEqual(config, reconstructed_wrapper.get_config()) 3084 self.assertIsInstance(reconstructed_wrapper, wrapper_cls) 3085 3086 wrapper = wrapper_cls(cell, residual_fn=lambda i, o: i + i + o) 3087 config = wrapper.get_config() 3088 3089 config_copy = config.copy() 3090 config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( 3091 config_copy["cell"]["config"]) 3092 reconstructed_wrapper = wrapper_cls.from_config(config_copy) 3093 # Assert the reconstructed function will perform the math correctly. 3094 self.assertEqual(reconstructed_wrapper._residual_fn(1, 2), 4) 3095 3096 def residual_fn(inputs, outputs): 3097 return inputs * 3 + outputs 3098 3099 wrapper = wrapper_cls(cell, residual_fn=residual_fn) 3100 config = wrapper.get_config() 3101 3102 config_copy = config.copy() 3103 config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( 3104 config_copy["cell"]["config"]) 3105 reconstructed_wrapper = wrapper_cls.from_config(config_copy) 3106 # Assert the reconstructed function will perform the math correctly. 3107 self.assertEqual(reconstructed_wrapper._residual_fn(1, 2), 5) 3108 3109 def testDropoutWrapperSerialization(self): 3110 wrapper_cls = rnn_cell_impl.DropoutWrapper 3111 cell = rnn_cell_impl.LSTMCell(10) 3112 wrapper = wrapper_cls(cell) 3113 config = wrapper.get_config() 3114 3115 config_copy = config.copy() 3116 config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( 3117 config_copy["cell"]["config"]) 3118 reconstructed_wrapper = wrapper_cls.from_config(config_copy) 3119 self.assertDictEqual(config, reconstructed_wrapper.get_config()) 3120 self.assertIsInstance(reconstructed_wrapper, wrapper_cls) 3121 3122 wrapper = wrapper_cls(cell, dropout_state_filter_visitor=lambda s: True) 3123 config = wrapper.get_config() 3124 3125 config_copy = config.copy() 3126 config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( 3127 config_copy["cell"]["config"]) 3128 reconstructed_wrapper = wrapper_cls.from_config(config_copy) 3129 self.assertTrue(reconstructed_wrapper._dropout_state_filter(None)) 3130 3131 def dropout_state_filter_visitor(unused_state): 3132 return False 3133 3134 wrapper = wrapper_cls( 3135 cell, dropout_state_filter_visitor=dropout_state_filter_visitor) 3136 config = wrapper.get_config() 3137 3138 config_copy = config.copy() 3139 config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( 3140 config_copy["cell"]["config"]) 3141 reconstructed_wrapper = wrapper_cls.from_config(config_copy) 3142 self.assertFalse(reconstructed_wrapper._dropout_state_filter(None)) 3143 3144 def testSavedModel(self): 3145 if test_util.is_gpu_available(): 3146 self.skipTest("b/175887901") 3147 3148 with self.cached_session(): 3149 root = autotrackable.AutoTrackable() 3150 root.cell = rnn_cell_impl.LSTMCell(8) 3151 @def_function.function(input_signature=[tensor_spec.TensorSpec([3, 8])]) 3152 def call(x): 3153 state = root.cell.zero_state(3, dtype=x.dtype) 3154 y, _ = root.cell(x, state) 3155 return y 3156 root.call = call 3157 expected = root.call(array_ops.zeros((3, 8))) 3158 self.evaluate(variables_lib.global_variables_initializer()) 3159 3160 save_dir = os.path.join(self.get_temp_dir(), "saved_model") 3161 save.save(root, save_dir) 3162 loaded = load.load(save_dir) 3163 self.evaluate(variables_lib.global_variables_initializer()) 3164 self.assertAllClose( 3165 expected, 

  def testSavedModel(self):
    if test_util.is_gpu_available():
      self.skipTest("b/175887901")

    with self.cached_session():
      root = autotrackable.AutoTrackable()
      root.cell = rnn_cell_impl.LSTMCell(8)

      @def_function.function(input_signature=[tensor_spec.TensorSpec([3, 8])])
      def call(x):
        state = root.cell.zero_state(3, dtype=x.dtype)
        y, _ = root.cell(x, state)
        return y

      root.call = call
      expected = root.call(array_ops.zeros((3, 8)))
      self.evaluate(variables_lib.global_variables_initializer())

      save_dir = os.path.join(self.get_temp_dir(), "saved_model")
      save.save(root, save_dir)
      loaded = load.load(save_dir)
      self.evaluate(variables_lib.global_variables_initializer())
      self.assertAllClose(
          expected,
          loaded.call(array_ops.zeros((3, 8))))


@test_util.run_all_in_graph_and_eager_modes
@test_util.run_all_without_tensor_float_32(
    "Uses an LSTMCell, which calls matmul")
class DropoutWrapperTest(test.TestCase, parameterized.TestCase):

  def _testDropoutWrapper(self,
                          batch_size=None,
                          time_steps=None,
                          parallel_iterations=None,
                          wrapper_type=None,
                          scope="root",
                          **kwargs):
    if batch_size is None and time_steps is None:
      # 2 time steps, batch size 1, depth 3
      batch_size = 1
      time_steps = 2
      x = constant_op.constant(
          [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
      m = rnn_cell_impl.LSTMStateTuple(
          *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32)] * 2)
    else:
      x = constant_op.constant(
          np.random.randn(time_steps, batch_size, 3).astype(np.float32))
      m = rnn_cell_impl.LSTMStateTuple(*[
          constant_op.constant(
              [[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
      ] * 2)
    outputs, final_state = rnn.dynamic_rnn(
        cell=wrapper_type(
            rnn_cell_impl.LSTMCell(
                3, initializer=init_ops.constant_initializer(0.5)),
            dtype=x.dtype, **kwargs),
        time_major=True,
        parallel_iterations=parallel_iterations,
        inputs=x,
        initial_state=m,
        scope=scope)
    self.evaluate([variables_lib.global_variables_initializer()])
    res = self.evaluate([outputs, final_state])
    self.assertEqual(res[0].shape, (time_steps, batch_size, 3))
    self.assertEqual(res[1].c.shape, (batch_size, 3))
    self.assertEqual(res[1].h.shape, (batch_size, 3))
    return res

  def testDropoutWrapperProperties(self):
    wrapper_type = rnn_cell_impl.DropoutWrapper
    cell = rnn_cell_impl.BasicRNNCell(10)
    wrapper = wrapper_type(cell)
    # Github issue 15810
    self.assertEqual(wrapper.wrapped_cell, cell)
    self.assertEqual(wrapper.state_size, 10)
    self.assertEqual(wrapper.output_size, 10)

  def testDropoutWrapperZeroState(self):
    wrapper_type = rnn_cell_impl.DropoutWrapper

    class _Cell(rnn_cell_impl.BasicRNNCell):

      def zero_state(self, batch_size=None, dtype=None):
        return "wrapped_cell_zero_state"

    wrapper = wrapper_type(_Cell(10))
    self.assertEqual(wrapper.zero_state(10, dtypes.float32),
                     "wrapped_cell_zero_state")

  def testDropoutWrapperKeepAllConstantInput(self):
    wrapper_type = rnn_cell_impl.DropoutWrapper
    keep = array_ops.ones([])
    res = self._testDropoutWrapper(
        input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep,
        wrapper_type=wrapper_type)
    true_full_output = np.array(
        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
        dtype=np.float32)
    true_full_final_c = np.array(
        [[1.949385, 1.949385, 1.949385]], dtype=np.float32)
    self.assertAllClose(true_full_output, res[0])
    self.assertAllClose(true_full_output[1], res[1].h)
    self.assertAllClose(true_full_final_c, res[1].c)
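
  # The "KeepAll"/"KeepNo*" tests below all compare against the same reference
  # values (true_full_output / true_full_final_c), which are the deterministic
  # results of the constant-initialized LSTM built by _testDropoutWrapper on
  # its fixed inputs. With every keep probability at 1.0 the wrapper should
  # pass inputs, outputs, and state through unchanged and reproduce those
  # values exactly; driving a probability toward 0 should make the
  # corresponding quantity deviate or drop to zero. Illustrative construction,
  # mirroring what _testDropoutWrapper builds:
  #
  #   cell = rnn_cell_impl.DropoutWrapper(
  #       rnn_cell_impl.LSTMCell(3),
  #       input_keep_prob=1.0, output_keep_prob=1.0, state_keep_prob=1.0,
  #       dtype=dtypes.float32)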

  def testDropoutWrapperKeepAll(self):
    wrapper_type = rnn_cell_impl.DropoutWrapper
    keep = variable_scope.get_variable("all", initializer=1.0)
    res = self._testDropoutWrapper(
        input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep,
        wrapper_type=wrapper_type)
    true_full_output = np.array(
        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
        dtype=np.float32)
    true_full_final_c = np.array(
        [[1.949385, 1.949385, 1.949385]], dtype=np.float32)
    self.assertAllClose(true_full_output, res[0])
    self.assertAllClose(true_full_output[1], res[1].h)
    self.assertAllClose(true_full_final_c, res[1].c)

  def testDropoutWrapperWithSeed(self):
    wrapper_type = rnn_cell_impl.DropoutWrapper
    keep_some = 0.5
    random_seed.set_random_seed(2)
    ## Use parallel_iterations = 1 in both calls to
    ## _testDropoutWrapper to ensure the (per-time step) dropout is
    ## consistent across both calls. Otherwise the seed may not end
    ## up being munged consistently across both graphs.
    res_standard_1 = self._testDropoutWrapper(
        input_keep_prob=keep_some,
        output_keep_prob=keep_some,
        state_keep_prob=keep_some,
        seed=10,
        parallel_iterations=1,
        wrapper_type=wrapper_type,
        scope="root_1")
    random_seed.set_random_seed(2)
    res_standard_2 = self._testDropoutWrapper(
        input_keep_prob=keep_some,
        output_keep_prob=keep_some,
        state_keep_prob=keep_some,
        seed=10,
        parallel_iterations=1,
        wrapper_type=wrapper_type,
        scope="root_2")
    self.assertAllClose(res_standard_1[0], res_standard_2[0])
    self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c)
    self.assertAllClose(res_standard_1[1].h, res_standard_2[1].h)

  def testDropoutWrapperKeepNoOutput(self):
    wrapper_type = rnn_cell_impl.DropoutWrapper
    keep_all = variable_scope.get_variable("all", initializer=1.0)
    keep_none = variable_scope.get_variable("none", initializer=1e-6)
    res = self._testDropoutWrapper(
        input_keep_prob=keep_all,
        output_keep_prob=keep_none,
        state_keep_prob=keep_all,
        wrapper_type=wrapper_type)
    true_full_output = np.array(
        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
        dtype=np.float32)
    true_full_final_c = np.array(
        [[1.949385, 1.949385, 1.949385]], dtype=np.float32)
    self.assertAllClose(np.zeros(res[0].shape), res[0])
    self.assertAllClose(true_full_output[1], res[1].h)
    self.assertAllClose(true_full_final_c, res[1].c)
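
  # By default, state dropout is not applied to the memory ("c") term of an
  # LSTMStateTuple; only the "h" part is dropped. The test below relies on
  # that default, while testDropoutWrapperSerialization above exercises the
  # dropout_state_filter_visitor argument, which overrides it by returning
  # True for the parts of the state that should receive dropout. Illustrative
  # override (returning True everywhere should apply dropout to every state
  # component):
  #
  #   cell = rnn_cell_impl.DropoutWrapper(
  #       rnn_cell_impl.LSTMCell(3), state_keep_prob=0.5,
  #       dropout_state_filter_visitor=lambda state: True,
  #       dtype=dtypes.float32)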

  def testDropoutWrapperKeepNoStateExceptLSTMCellMemory(self):
    wrapper_type = rnn_cell_impl.DropoutWrapper
    keep_all = variable_scope.get_variable("all", initializer=1.0)
    keep_none = variable_scope.get_variable("none", initializer=1e-6)
    # Even though we drop out the state, by default DropoutWrapper never
    # drops out the memory ("c") term of an LSTMStateTuple.
    res = self._testDropoutWrapper(
        input_keep_prob=keep_all,
        output_keep_prob=keep_all,
        state_keep_prob=keep_none,
        wrapper_type=wrapper_type)
    true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32)
    true_full_output = np.array(
        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
        dtype=np.float32)
    self.assertAllClose(true_full_output[0], res[0][0])
    # The second output is modified by the zeroed input state.
    self.assertGreater(np.linalg.norm(true_full_output[1] - res[0][1]), 1e-4)
    # The h state has been set to zero.
    self.assertAllClose(np.zeros(res[1].h.shape), res[1].h)
    # The c state of an LSTMStateTuple is NEVER modified.
    self.assertAllClose(true_c_state, res[1].c)

  def testDropoutWrapperKeepNoInput(self):
    wrapper_type = rnn_cell_impl.DropoutWrapper
    keep_all = variable_scope.get_variable("all", initializer=1.0)
    keep_none = variable_scope.get_variable("none", initializer=1e-6)
    true_full_output = np.array(
        [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
        dtype=np.float32)
    true_full_final_c = np.array(
        [[1.949385, 1.949385, 1.949385]], dtype=np.float32)
    # All outputs are different because the inputs are zeroed out.
    res = self._testDropoutWrapper(
        input_keep_prob=keep_none,
        output_keep_prob=keep_all,
        state_keep_prob=keep_all,
        wrapper_type=wrapper_type)
    self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4)
    self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4)
    self.assertGreater(np.linalg.norm(res[1].c - true_full_final_c), 1e-4)
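
  # The remaining tests turn on variational_recurrent=True, under which the
  # wrapper samples a dropout mask once and reuses it at every time step
  # rather than resampling per step; the assertions below check exactly that
  # by comparing the zero patterns across time steps. When input dropout is
  # applied in this mode the wrapper also needs the static input depth, which
  # is why the tests forward input_size=3 through _testDropoutWrapper's
  # **kwargs. Illustrative construction:
  #
  #   cell = rnn_cell_impl.DropoutWrapper(
  #       rnn_cell_impl.LSTMCell(3),
  #       input_keep_prob=0.9, output_keep_prob=0.9, state_keep_prob=0.9,
  #       variational_recurrent=True, input_size=3, dtype=dtypes.float32)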

  def testDropoutWrapperRecurrentOutput(self):
    wrapper_type = rnn_cell_impl.DropoutWrapper
    keep_some = 0.8
    keep_all = variable_scope.get_variable("all", initializer=1.0)
    res = self._testDropoutWrapper(
        input_keep_prob=keep_all,
        output_keep_prob=keep_some,
        state_keep_prob=keep_all,
        variational_recurrent=True,
        wrapper_type=wrapper_type,
        input_size=3,
        batch_size=5,
        time_steps=7)
    # Ensure the same dropout pattern for all time steps
    output_mask = np.abs(res[0]) > 1e-6
    for m in output_mask[1:]:
      self.assertAllClose(output_mask[0], m)

  def testDropoutWrapperRecurrentStateInputAndOutput(self):
    wrapper_type = rnn_cell_impl.DropoutWrapper
    keep_some = 0.9
    res = self._testDropoutWrapper(
        input_keep_prob=keep_some,
        output_keep_prob=keep_some,
        state_keep_prob=keep_some,
        variational_recurrent=True,
        wrapper_type=wrapper_type,
        input_size=3,
        batch_size=5,
        time_steps=7)

    # Smoke test for the state/input masks.
    output_mask = np.abs(res[0]) > 1e-6
    for time_step in output_mask:
      # Ensure the same dropout output pattern for all time steps
      self.assertAllClose(output_mask[0], time_step)
      for batch_entry in time_step:
        # Assert all batch entries get the same mask
        self.assertAllClose(batch_entry, time_step[0])

    # For state, ensure all batch entries have the same mask
    state_c_mask = np.abs(res[1].c) > 1e-6
    state_h_mask = np.abs(res[1].h) > 1e-6
    for batch_entry in state_c_mask:
      self.assertAllClose(batch_entry, state_c_mask[0])
    for batch_entry in state_h_mask:
      self.assertAllClose(batch_entry, state_h_mask[0])

  def testDropoutWrapperRecurrentStateInputAndOutputWithSeed(self):
    wrapper_type = rnn_cell_impl.DropoutWrapper
    keep_some = 0.9
    random_seed.set_random_seed(2347)
    np.random.seed(23487)
    res0 = self._testDropoutWrapper(
        input_keep_prob=keep_some,
        output_keep_prob=keep_some,
        state_keep_prob=keep_some,
        variational_recurrent=True,
        wrapper_type=wrapper_type,
        input_size=3,
        batch_size=5,
        time_steps=7,
        seed=-234987,
        scope="root_0")
    random_seed.set_random_seed(2347)
    np.random.seed(23487)
    res1 = self._testDropoutWrapper(
        input_keep_prob=keep_some,
        output_keep_prob=keep_some,
        state_keep_prob=keep_some,
        variational_recurrent=True,
        wrapper_type=wrapper_type,
        input_size=3,
        batch_size=5,
        time_steps=7,
        seed=-234987,
        scope="root_1")

    output_mask = np.abs(res0[0]) > 1e-6
    for time_step in output_mask:
      # Ensure the same dropout output pattern for all time steps
      self.assertAllClose(output_mask[0], time_step)
      for batch_entry in time_step:
        # Assert all batch entries get the same mask
        self.assertAllClose(batch_entry, time_step[0])

    # For state, ensure all batch entries have the same mask
    state_c_mask = np.abs(res0[1].c) > 1e-6
    state_h_mask = np.abs(res0[1].h) > 1e-6
    for batch_entry in state_c_mask:
      self.assertAllClose(batch_entry, state_c_mask[0])
    for batch_entry in state_h_mask:
      self.assertAllClose(batch_entry, state_h_mask[0])

    # Ensure seeded calculation is identical.
    self.assertAllClose(res0[0], res1[0])
    self.assertAllClose(res0[1].c, res1[1].c)
    self.assertAllClose(res0[1].h, res1[1].h)


if __name__ == "__main__":
  test.main()