# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tensorflow.python.data.util.random_seed`."""

import functools

from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.util import random_seed as data_random_seed
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.platform import test


# NOTE(vikoth18): Arguments of parameterized tests are lifted into lambdas to
# make sure they are not executed before the (eager- or graph-mode) test
# environment has been set up.
def _test_random_seed_combinations():
  """Returns test combinations pairing `get_seed` inputs with expected outputs.

  Each case yields one combination with two `combinations.NamedObject`
  parameters:

    * `input_fn`: a thunk returning `(input_graph_seed, input_op_seed)`.
    * `output_fn`: a thunk returning the expected
      `(output_graph_seed, output_op_seed)`.

  The values are wrapped in lambdas so that tensor-valued arguments are only
  constructed once the eager- or graph-mode test environment is active.
  """

  cases = [
      # Each test case is a tuple with input to get_seed:
      # (input_graph_seed, input_op_seed)
      # and output from get_seed:
      # (output_graph_seed, output_op_seed)
      (
          "TestCase_0",
          lambda: (None, None),
          lambda: (0, 0),
      ),
      ("TestCase_1", lambda: (None, 1),
       lambda: (random_seed.DEFAULT_GRAPH_SEED, 1)),
      ("TestCase_2", lambda: (1, 1), lambda: (1, 1)),
      (
          # Avoid nondeterministic (0, 0) output
          "TestCase_3",
          lambda: (0, 0),
          lambda: (0, 2**31 - 1)),
      (
          # Don't wrap to (0, 0) either
          "TestCase_4",
          lambda: (2**31 - 1, 0),
          lambda: (0, 2**31 - 1)),
      (
          # Wrapping for the other argument
          "TestCase_5",
          lambda: (0, 2**31 - 1),
          lambda: (0, 2**31 - 1)),
      (
          # Once more, with tensor-valued arguments
          "TestCase_6",
          lambda:
          (None, constant_op.constant(1, dtype=dtypes.int64, name="one")),
          lambda: (random_seed.DEFAULT_GRAPH_SEED, 1)),
      ("TestCase_7",
       lambda: (1, constant_op.constant(1, dtype=dtypes.int64, name="one")),
       lambda: (1, 1)),
      (
          "TestCase_8",
          lambda: (0, constant_op.constant(0, dtype=dtypes.int64, name="zero")),
          lambda: (0, 2**31 - 1)  # Avoid nondeterministic (0, 0) output
      ),
      (
          "TestCase_9",
          lambda:
          (2**31 - 1, constant_op.constant(0, dtype=dtypes.int64, name="zero")),
          lambda: (0, 2**31 - 1)  # Don't wrap to (0, 0) either
      ),
      (
          "TestCase_10",
          lambda:
          (0, constant_op.constant(
              2**31 - 1, dtype=dtypes.int64, name="intmax")),
          lambda: (0, 2**31 - 1)  # Wrapping for the other argument
      )
  ]

  def reduce_fn(x, y):
    # Append the single combination built from case `y` to accumulator `x`.
    name, input_fn, output_fn = y
    return x + combinations.combine(
        input_fn=combinations.NamedObject("input_fn.{}".format(name), input_fn),
        output_fn=combinations.NamedObject("output_fn.{}".format(name),
                                           output_fn))

  return functools.reduce(reduce_fn, cases, [])


class RandomSeedTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `data_random_seed.get_seed`."""

  def _checkEqual(self, tinput, toutput):
    """Sets the global seed, calls `get_seed`, and checks the evaluated result.

    Args:
      tinput: `(graph_seed, op_seed)` pair. `graph_seed` is passed to
        `random_seed.set_random_seed`; `op_seed` is passed to
        `data_random_seed.get_seed`.
      toutput: the expected `(graph_seed, op_seed)` pair after evaluating the
        values returned by `get_seed`.
    """
    random_seed.set_random_seed(tinput[0])
    g_seed, op_seed = data_random_seed.get_seed(tinput[1])
    g_seed = self.evaluate(g_seed)
    op_seed = self.evaluate(op_seed)
    msg = "test_case = {0}, got {1}, want {2}".format(tinput, (g_seed, op_seed),
                                                      toutput)
    self.assertEqual((g_seed, op_seed), toutput, msg=msg)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_random_seed_combinations()))
  def testRandomSeed(self, input_fn, output_fn):
    """Checks each `(input, expected output)` case from the combinations."""
    tinput, toutput = input_fn(), output_fn()
    try:
      self._checkEqual(tinput=tinput, toutput=toutput)
    finally:
      # Clear the global seed even when the check fails, so a failing case
      # does not leak its seed into subsequently executed tests.
      random_seed.set_random_seed(None)

  @combinations.generate(test_base.graph_only_combinations())
  def testIncrementalRandomSeed(self):
    """Checks that repeated calls with no op seed yield op seeds 0, 1, 2, ..."""
    tinput = (1, None)
    random_seed.set_random_seed(1)
    try:
      for i in range(10):
        self._checkEqual(tinput=tinput, toutput=(1, i))
    finally:
      # Mirror testRandomSeed: never leave a global seed set after the test.
      random_seed.set_random_seed(None)


if __name__ == "__main__":
  test.main()