xref: /aosp_15_r20/external/tensorflow/tensorflow/python/data/util/random_seed_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for the tf.data `random_seed` utility (`data.util.random_seed`)."""
16
17import functools
18
19from absl.testing import parameterized
20
21from tensorflow.python.data.kernel_tests import test_base
22from tensorflow.python.data.util import random_seed as data_random_seed
23from tensorflow.python.framework import combinations
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import random_seed
27from tensorflow.python.platform import test
28
29
30# NOTE(vikoth18): Arguments of parameterized tests are lifted into lambdas to make
31# sure they are not executed before the (eager- or graph-mode) test environment
32# has been set up.
33
34
def _test_random_seed_combinations():
  """Builds the parameter combinations consumed by `testRandomSeed`.

  Each case is a `(name, input_fn, output_fn)` triple:
    * `input_fn()` returns `(input_graph_seed, input_op_seed)`. The graph seed
      is installed with `set_random_seed` and the op seed is passed to
      `get_seed`.
    * `output_fn()` returns the `(output_graph_seed, output_op_seed)` pair that
      `get_seed` is expected to produce.
  Both are zero-argument callables, so any tensors they build are created only
  when the test body calls them.

  Returns:
    A list of combinations, one per case. Each binds `input_fn` and
    `output_fn` to a `combinations.NamedObject` labelled with the case name.
  """
  int_max = 2**31 - 1

  cases = [
      ("TestCase_0", lambda: (None, None), lambda: (0, 0)),
      ("TestCase_1", lambda: (None, 1),
       lambda: (random_seed.DEFAULT_GRAPH_SEED, 1)),
      ("TestCase_2", lambda: (1, 1), lambda: (1, 1)),
      # Avoid nondeterministic (0, 0) output.
      ("TestCase_3", lambda: (0, 0), lambda: (0, int_max)),
      # Don't wrap to (0, 0) either.
      ("TestCase_4", lambda: (int_max, 0), lambda: (0, int_max)),
      # Wrapping for the other argument.
      ("TestCase_5", lambda: (0, int_max), lambda: (0, int_max)),
      # The same scenarios again, this time with tensor-valued op seeds.
      ("TestCase_6",
       lambda: (None, constant_op.constant(1, dtype=dtypes.int64, name="one")),
       lambda: (random_seed.DEFAULT_GRAPH_SEED, 1)),
      ("TestCase_7",
       lambda: (1, constant_op.constant(1, dtype=dtypes.int64, name="one")),
       lambda: (1, 1)),
      # Avoid nondeterministic (0, 0) output.
      ("TestCase_8",
       lambda: (0, constant_op.constant(0, dtype=dtypes.int64, name="zero")),
       lambda: (0, int_max)),
      # Don't wrap to (0, 0) either.
      ("TestCase_9",
       lambda: (int_max,
                constant_op.constant(0, dtype=dtypes.int64, name="zero")),
       lambda: (0, int_max)),
      # Wrapping for the other argument.
      ("TestCase_10",
       lambda: (0, constant_op.constant(
           int_max, dtype=dtypes.int64, name="intmax")),
       lambda: (0, int_max)),
  ]

  all_combinations = []
  for case_name, make_input, make_output in cases:
    all_combinations.extend(
        combinations.combine(
            input_fn=combinations.NamedObject(f"input_fn.{case_name}",
                                              make_input),
            output_fn=combinations.NamedObject(f"output_fn.{case_name}",
                                               make_output)))
  return all_combinations
102
103
class RandomSeedTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Checks the seeds that `data_random_seed.get_seed` derives from inputs."""

  def _checkEqual(self, tinput, toutput):
    """Asserts that `get_seed` maps `tinput` to `toutput`.

    Args:
      tinput: `(graph_seed, op_seed)` pair. `graph_seed` is installed via
        `random_seed.set_random_seed`. `op_seed` (a Python value, a tensor,
        or `None`) is passed to `data_random_seed.get_seed`.
      toutput: The expected `(graph_seed, op_seed)` pair after evaluation.
    """
    random_seed.set_random_seed(tinput[0])
    g_seed, op_seed = data_random_seed.get_seed(tinput[1])
    # Materialize the returned seeds as concrete values before comparing.
    g_seed = self.evaluate(g_seed)
    op_seed = self.evaluate(op_seed)
    msg = "test_case = {0}, got {1}, want {2}".format(tinput, (g_seed, op_seed),
                                                      toutput)
    self.assertEqual((g_seed, op_seed), toutput, msg=msg)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_random_seed_combinations()))
  def testRandomSeed(self, input_fn, output_fn):
    """Checks a single (input, expected output) case from the combinations."""
    # The case arguments are lambdas so that any tensors are only built here,
    # after the test's execution mode has been set up.
    tinput, toutput = input_fn(), output_fn()
    try:
      self._checkEqual(tinput=tinput, toutput=toutput)
    finally:
      # Always clear the global seed, even when the check fails. Otherwise a
      # failing case would leak its seed into subsequently run tests.
      random_seed.set_random_seed(None)

  @combinations.generate(test_base.graph_only_combinations())
  def testIncrementalRandomSeed(self):
    """Checks repeated `get_seed(None)` calls give op seeds 0, 1, ..., 9."""
    random_seed.set_random_seed(1)
    # The input is identical on every iteration; only the expected op seed
    # advances.
    tinput = (1, None)
    for i in range(10):
      self._checkEqual(tinput=tinput, toutput=(1, i))
131
if __name__ == "__main__":
  # Delegate to TensorFlow's test runner (tensorflow.python.platform.test).
  test.main()
134