# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stateless random ops."""

import functools

from absl.testing import parameterized
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import stateless_random_ops as stateless
from tensorflow.python.platform import test


# Note that in theory each test will reset the eager context and may choose to
# hide some devices, so we shouldn't cache this transient info. Tests in this
# file don't make those config changes, so caching is fine. It provides a good
# speed-up.
_cached_device = None


def get_device():
  global _cached_device
  if _cached_device is not None:
    return _cached_device
  # Precedence from high to low
  for device_type in ('XLA_GPU', 'GPU', 'XLA_CPU', 'CPU'):
    devices = config.list_logical_devices(device_type)
    if devices:
      _cached_device = devices[0]
      return _cached_device
  raise ValueError('Cannot find any suitable device. Available devices: %s' %
                   config.list_logical_devices())
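

# The tests below pin ops to this device explicitly, e.g.
# `with ops.device(get_device().name): ...`. Soft device placement is disabled
# in main(), so a missing kernel raises an error instead of silently falling
# back to CPU.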


BEFORE_EXPIRE = (2020, 10, 24)
AFTER_EXPIRE = (2020, 10, 26)


def invert_philox(key, value):
  """Invert the Philox bijection."""
  key = np.array(key, dtype=np.uint32)
  value = np.array(value, dtype=np.uint32)
  step = np.array([0x9E3779B9, 0xBB67AE85], dtype=np.uint32)
  for n in range(10)[::-1]:
    key0, key1 = key + n * step
    v0 = value[3] * 0x991a7cdb & 0xffffffff
    v2 = value[1] * 0x6d7cae67 & 0xffffffff
    hi0 = v0 * 0xD2511F53 >> 32
    hi1 = v2 * 0xCD9E8D57 >> 32
    v1 = hi1 ^ value[0] ^ key0
    v3 = hi0 ^ value[2] ^ key1
    value = v0, v1, v2, v3
  return np.array(value)
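

# Illustration (mirrors _test_match below): for a test seed (s0, s1), the
# 128-bit Philox preimage used as the stateless seed is recovered via
#   preseed = invert_philox(key, (s0, 0, s1, 0)).astype(np.uint64)
#   preseed = preseed[::2] | preseed[1::2] << 32
# so that a stateless op fed `preseed` reproduces the first draw of the
# corresponding stateful op seeded with (s0, s1).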


SEEDS = ((7, 17), (11, 5), (2, 3))
SEED_TYPES = [dtypes.int32, dtypes.int64]


def float_cases(shape_dtypes=(None,)):
  cases = (
      # Uniform distribution, with and without range
      ('uniform', stateless.stateless_random_uniform, random_ops.random_uniform,
       {}),
      ('uniform2', stateless.stateless_random_uniform,
       random_ops.random_uniform, dict(minval=2.2, maxval=7.1)),
      # Normal distribution, with and without mean+stddev
      ('normal', stateless.stateless_random_normal, random_ops.random_normal,
       {}),
      ('normal2', stateless.stateless_random_normal, random_ops.random_normal,
       dict(mean=2, stddev=3)),
      # Truncated normal distribution, with and without mean+stddev
      ('trnorm', stateless.stateless_truncated_normal,
       random_ops.truncated_normal, {}),
      ('trnorm2', stateless.stateless_truncated_normal,
       random_ops.truncated_normal, dict(mean=3, stddev=4)),
  )
  # Explicitly pass in params because capturing a cell variable from a loop is
  # problematic in Python.
  def wrap(op, dtype, shape, shape_dtype, seed, **kwargs):
    device_type = get_device().device_type
    # Some dtypes are not supported on some devices
    if (dtype == dtypes.float16 and device_type in ('XLA_GPU', 'XLA_CPU') or
        dtype == dtypes.bfloat16 and device_type == 'GPU'):
      dtype = dtypes.float32
    shape_ = (constant_op.constant(shape, dtype=shape_dtype)
              if shape_dtype is not None else shape)
    return op(seed=seed, shape=shape_, dtype=dtype, **kwargs)

  def _name(a):
    if hasattr(a, 'name'):
      return a.name
    else:
      return a

  for dtype in dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64:
    for shape_dtype in shape_dtypes:
      for shape in (), (3,), (2, 5):
        for name, stateless_op, stateful_op, kwargs in cases:
          yield (('%s_%s_%s_%s' %
                  (name, _name(dtype), shape, _name(shape_dtype))).replace(
                      ' ', ''),
                 functools.partial(wrap, stateless_op, dtype, shape,
                                   shape_dtype, **kwargs),
                 functools.partial(wrap, stateful_op, dtype, shape, shape_dtype,
                                   **kwargs))
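

# Each case generator yields (name, stateless_fn, stateful_fn) triples whose
# partials only need a `seed` argument when called; float_cases() above, for
# example, yields names such as 'uniform_float32_(3,)_None'.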


def int_cases(shape_dtypes=(None,), minval_maxval=None):

  def wrap(op, minval, maxval, shape, shape_dtype, dtype, seed, **kwargs):
    shape_ = (constant_op.constant(shape, dtype=shape_dtype)
              if shape_dtype is not None else shape)
    return op(
        seed=seed, shape=shape_, minval=minval, maxval=maxval, dtype=dtype,
        **kwargs)

  if minval_maxval is None:
    minval_maxval = ((2, 11111),)
  for minval, maxval in minval_maxval:
    for shape_dtype in shape_dtypes:
      for shape in (), (3,), (2, 5):
        for dtype in dtypes.int32, dtypes.int64:
          yield ('uniform_%s_%s' % (minval, maxval),
                 functools.partial(wrap, stateless.stateless_random_uniform,
                                   minval, maxval, shape, shape_dtype, dtype),
                 functools.partial(wrap, random_ops.random_uniform, minval,
                                   maxval, shape, shape_dtype, dtype))


def multinomial_cases():
  num_samples = 10
  def wrap(op, logits, logits_dtype, output_dtype, seed):
    return op(seed=seed,
              logits=constant_op.constant(logits, dtype=logits_dtype),
              num_samples=num_samples, output_dtype=output_dtype)
  for logits_dtype in np.float16, np.float32, np.float64:
    for output_dtype in dtypes.int32, dtypes.int64:
      for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
                                                [0.25, 0.75]]):
        yield ('multinomial',
               functools.partial(wrap, stateless.stateless_multinomial, logits,
                                 logits_dtype, output_dtype),
               functools.partial(wrap, random_ops.multinomial, logits,
                                 logits_dtype, output_dtype))


def gamma_cases():
  def wrap(op, alpha, dtype, shape, seed):
    return op(seed=seed, shape=shape,
              alpha=constant_op.constant(alpha, dtype=dtype), dtype=dtype)
  for dtype in np.float16, np.float32, np.float64:
    for alpha in ([[.5, 1., 2.]], [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]]):
      yield ('gamma',
             functools.partial(wrap, stateless.stateless_random_gamma, alpha,
                               dtype, (10,) + tuple(np.shape(alpha))),
             functools.partial(wrap, random_ops.random_gamma, alpha, dtype,
                               (10,)))


def poisson_cases():
  def wrap(op, lam, lam_dtype, out_dtype, shape, seed):
    return op(seed=seed, shape=shape,
              lam=constant_op.constant(lam_dtype(lam), dtype=lam_dtype),
              dtype=out_dtype)
  for lam_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
    for out_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
      for lam in ([[5.5, 1., 2.]], [[7.5, 10.5], [3.8, 8.2], [1.25, 9.75]]):
        yield ('poisson',
               functools.partial(wrap, stateless.stateless_random_poisson, lam,
                                 lam_dtype, out_dtype,
                                 (10,) + tuple(np.shape(lam))),
               functools.partial(wrap, random_ops.random_poisson, lam,
                                 lam_dtype, out_dtype, (10,)))


def shuffle_cases():
  for dtype in np.int32, np.int64, np.float32, np.float64:
    # [], [0, ...] and [1, ...] are important corner cases
    for shape in ([], [0], [1], [100], [0, 0], [1, 0], [0, 1], [1, 2], [5, 3],
                  [7, 5, 3, 2]):
      value = np.arange(np.prod(shape)).reshape(shape).astype(dtype)
      yield ('shuffle',
             functools.partial(stateless.stateless_shuffle, value),
             functools.partial(random_ops.random_shuffle, value))


@test_util.with_eager_op_as_function
class StatelessOpsTest(test.TestCase, parameterized.TestCase):

  def _test_match(self, case, seed):
    # Stateless ops should be the same as stateful ops on the first call
    # after seed scrambling.
    key = 0x3ec8f720, 0x02461e29
    preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
    preseed = preseed[::2] | preseed[1::2] << 32
    with ops.device(get_device().name):
      _, stateless_op, stateful_op = case
      random_seed.set_random_seed(seed[0])
      stateful = stateful_op(seed=seed[1])
      pure = stateless_op(seed=preseed)
      self.assertAllEqual(stateful, pure)

  def _test_match_stateless_cpu_gpu(self, case, seed):
    # Stateless ops should produce the same result on CPUs and GPUs.
    _, stateless_op, _ = case

    with ops.device('CPU'):
      result_cpu = stateless_op(seed=seed)

    with ops.device(get_device().name):
      result_gpu = stateless_op(seed=seed)
      self.assertAllClose(result_cpu, result_gpu)

  def _test_old_and_new_stateless_match(self, case, seed):
    """Tests that the new stateless ops match the old stateless ones."""
    with ops.device(get_device().name):
      _, stateless_op, _ = case
      with compat.forward_compatibility_horizon(*BEFORE_EXPIRE):
        old = stateless_op(seed=seed)
      with compat.forward_compatibility_horizon(*AFTER_EXPIRE):
        new = stateless_op(seed=seed)
      self.assertAllClose(old, new)

  def _test_explicit_alg(self, case, seed):
    """Tests that alg=philox and alg=None are the same (on CPU/GPU)."""
    with ops.device(get_device().name):
      _, stateless_op, _ = case
      implicit_alg = stateless_op(seed=seed)
      # All device types allowed in this test will result in Philox
      explicit_alg = stateless_op(seed=seed, alg='philox')
      self.assertAllClose(implicit_alg, explicit_alg)

  def _test_determinism(self, case, seed_type):
    # Stateless values should be equal iff the seeds are equal (roughly)
    seeds = [(x, y) for x in range(5) for y in range(5)] * 3  # pylint: disable=g-complex-comprehension
    with self.test_session(), ops.device(get_device().name):
      _, stateless_op, _ = case
      if context.executing_eagerly():
        values = [
            (seed, stateless_op(seed=constant_op.constant(seed, seed_type)))
            for seed in seeds]
      else:
        # This branch exists because the branch above is too slow in graph
        # mode.
        seed_t = array_ops.placeholder(seed_type, shape=[2])
        pure = stateless_op(seed=seed_t)
        values = [
            (seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds
        ]
      for s0, v0 in values:
        for s1, v1 in values:
          if dtypes.as_dtype(v0.dtype) != dtypes.bfloat16:
            self.assertEqual(s0 == s1, np.all(v0 == v1))
          elif s0 == s1:
            # Skip the s0 != s1 case because v0 and v1 can be either equal or
            # unequal in that case due to bfloat16's low precision
            self.assertAllEqual(v0, v1)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(float_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchFloat(self, case, seed):
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Skip on XLA because XLA kernels do not support int64 '
                    'seeds needed by this test.')
    self._test_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(int_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchInt(self, case, seed):
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Skip on XLA because XLA kernels do not support int64 '
                    'seeds needed by this test.')
    self._test_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(multinomial_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchMultinomial(self, case, seed):
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(gamma_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchGamma(self, case, seed):
    if get_device().device_type == 'GPU':
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking GPU kernel')
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(gamma_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testStatelessGammaCpuGpuMatch(self, case, seed):
    if get_device().device_type != 'GPU':
      # This test compares the numbers produced by the CPU and GPU kernel for
      # stateless_random_gamma.
      self.skipTest('This test requires GPU')
    self._test_match_stateless_cpu_gpu(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(poisson_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchPoisson(self, case, seed):
    if get_device().device_type == 'GPU':
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking GPU kernel')
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension,undefined-variable
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(shuffle_cases()))
  def testMatchShuffle(self, case, seed):
    if get_device().device_type == 'GPU':
      self.skipTest('Lacking GPU kernel')
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(float_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testOldAndNewStatelessMatchFloat(self, case, seed):
    self._test_old_and_new_stateless_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(
          int_cases(minval_maxval=((2, 11111), (None, None)))))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testOldAndNewStatelessMatchInt(self, case, seed):
    self._test_old_and_new_stateless_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s' % (case[0], case_id), case)
      for case_id, case in enumerate(float_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testExplicitAlgFloat(self, case):
    seed = (7, 17)
    self._test_explicit_alg(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s' % (case[0], case_id), case)
      for case_id, case in enumerate(
          int_cases(minval_maxval=((2, 11111), (None, None)))))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testExplicitAlgInt(self, case):
    seed = (7, 17)
    self._test_explicit_alg(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
      for seed_type in SEED_TYPES
      for case_id, case in enumerate(
          float_cases(shape_dtypes=(dtypes.int32, dtypes.int64))))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismFloat(self, case, seed_type):
    if seed_type == dtypes.int64 and get_device().device_type in ('XLA_GPU',
                                                                  'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest(
          'Skip on XLA because XLA kernels do not support int64 seeds.')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
      for seed_type in SEED_TYPES
      for case_id, case in enumerate(
          int_cases(shape_dtypes=(dtypes.int32, dtypes.int64))))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismInt(self, case, seed_type):
    if seed_type == dtypes.int64 and get_device().device_type in ('XLA_GPU',
                                                                  'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest(
          'Skip on XLA because XLA kernels do not support int64 seeds.')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
      for seed_type in SEED_TYPES
      for case_id, case in enumerate(multinomial_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismMultinomial(self, case, seed_type):
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking XLA kernel')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
      for seed_type in SEED_TYPES
      for case_id, case in enumerate(gamma_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismGamma(self, case, seed_type):
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking XLA kernel')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
      for seed_type in SEED_TYPES
      for case_id, case in enumerate(poisson_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismPoisson(self, case, seed_type):
    if get_device().device_type == 'GPU':
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking GPU kernel')
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking XLA kernel')
    self._test_determinism(case, seed_type)

  @test_util.run_v2_only
  def testGetKeyCounterAlg(self):
    seed = [1, 2]
    key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
        seed)
    self.assertAllEqual(key.shape, [1])
    self.assertAllEqual(counter.shape, [2])
    alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
    self.assertAllEqual(alg.shape, [])

  def assertDTypeEqual(self, a, b):
    self.assertEqual(dtypes.as_dtype(a), dtypes.as_dtype(b))

  def assertNoEqualPair(self, ls):
    for i in range(len(ls)):
      for j in range(i + 1, len(ls)):
        self.assertFalse(math_ops.reduce_all(ls[i] == ls[j]))

  @parameterized.parameters(['int32', 'int64'])
  @test_util.run_v2_only
  def testSplit(self, dtype):
    """Test for `split`."""
    seed = constant_op.constant([1, 2], dtype=dtype)
    new_seed = stateless.split(seed, 3)
    self.assertEqual(new_seed.shape, [3, 2])
    self.assertDTypeEqual(new_seed.dtype, dtype)
    self.assertNoEqualPair([seed] + array_ops.unstack(new_seed))

  @parameterized.parameters(['int32', 'int64'])
  @test_util.run_v2_only
  def testFoldIn(self, dtype):
    """Test for `fold_in`."""
    orig_seed = constant_op.constant([1, 2], dtype='int32')
    seed = stateless.fold_in(orig_seed, constant_op.constant(3, dtype=dtype))
    new_seeds = []
    new_seeds.append(seed)
    seed = stateless.fold_in(seed, constant_op.constant(4, dtype=dtype))
    new_seeds.append(seed)
    for s in new_seeds:
      self.assertEqual(s.shape, [2])
      self.assertDTypeEqual(s.dtype, dtype)
    self.assertNoEqualPair([math_ops.cast(orig_seed, dtype)] + new_seeds)

  @test_util.run_v2_only
  def testErrors(self):
    """Tests that proper errors are raised."""
    shape = [2, 3]
    with self.assertRaisesWithPredicateMatch(
        ValueError,
        'minval must be a scalar; got a tensor of shape '):
      @def_function.function
      def f():
        stateless.stateless_random_uniform(
            shape=shape, seed=[1, 2], minval=array_ops.zeros(shape, 'int32'),
            maxval=100, dtype='int32')
      f()
    with self.assertRaisesWithPredicateMatch(
        ValueError,
        'maxval must be a scalar; got a tensor of shape '):
      @def_function.function
      def f2():
        stateless.stateless_random_uniform(
            shape=shape, seed=[1, 2], minval=0,
            maxval=array_ops.ones(shape, 'int32') * 100,
            dtype='int32')
      f2()


if __name__ == '__main__':
  config.set_soft_device_placement(False)
  context.context().enable_xla_devices()
  test.main()