xref: /aosp_15_r20/external/tensorflow/tensorflow/python/data/kernel_tests/random_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.data.Dataset.random()`."""
16from absl.testing import parameterized
17
18from tensorflow.python.data.kernel_tests import test_base
19from tensorflow.python.data.ops import dataset_ops
20from tensorflow.python.framework import combinations
21from tensorflow.python.framework import random_seed
22from tensorflow.python.platform import test
23
24
class RandomTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Test cases for the `tf.data.Dataset.random()` source dataset."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              global_seed=[None, 10],
              local_seed=[None, 20],
          ),
      ))
  def testDeterminism(self, global_seed, local_seed):
    """Verifies reproducibility of `Dataset.random()` across a graph round trip.

    If at least one of the two seeds is set, the elements read before and
    after `graphRoundTrip` must match; if neither is set, they must differ.

    Args:
      global_seed: Value handed to `random_seed.set_random_seed`, or None.
      local_seed: Op-level `seed` handed to `Dataset.random`, or None.
    """
    seeded = any(s is not None for s in (global_seed, local_seed))

    random_seed.set_random_seed(global_seed)
    dataset = dataset_ops.Dataset.random(seed=local_seed).take(10)

    # Read the pipeline once as built, then once more after serializing it
    # through a graph round trip.
    before_round_trip = self.getDatasetOutput(dataset)
    after_round_trip = self.getDatasetOutput(self.graphRoundTrip(dataset))

    if not seeded:
      # Not strictly guaranteed: the two randomly-chosen int64 seeds could
      # coincide, but that is vanishingly unlikely (about 1 in 2^128 assuming
      # perfect random number generation).
      self.assertNotEqual(before_round_trip, after_round_trip)
      return

    self.assertEqual(before_round_trip, after_round_trip)

  @combinations.generate(test_base.default_test_combinations())
  def testName(self):
    """Checks that `Dataset.random()` accepts a `name` argument."""
    random_ds = dataset_ops.Dataset.random(seed=42, name="random")
    # Replace the random element with a constant so the output is predictable.
    constant_ds = random_ds.take(1).map(lambda _: 42)
    self.assertDatasetProduces(constant_ds, [42])
54
55
if __name__ == "__main__":
  # When executed as a script, run this module's tests via TensorFlow's
  # test entry point.
  test.main()
58