# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stateless random ops."""

import functools

from absl.testing import parameterized
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import stateless_random_ops as stateless
from tensorflow.python.platform import test


# Note that in theory each test will reset the eager context and may choose to
# hide some devices, so we shouldn't cache this transient info. Tests in this
# file don't make those config changes, so caching is fine. It provides a good
# speed-up.
_cached_device = None


def get_device():
  """Returns the highest-precedence logical device, caching the result.

  Device types are probed in the order XLA_GPU, GPU, XLA_CPU, CPU and the
  first logical device of the first type that has any is returned.

  Raises:
    ValueError: If none of the probed device types has a logical device.
  """
  global _cached_device
  if _cached_device is not None:
    return _cached_device
  # Precedence from high to low
  for device_type in ('XLA_GPU', 'GPU', 'XLA_CPU', 'CPU'):
    devices = config.list_logical_devices(device_type)
    if devices:
      _cached_device = devices[0]
      return _cached_device
  raise ValueError('Cannot find any suitable device. Available devices: %s' %
                   config.list_logical_devices())


BEFORE_EXPIRE = (2020, 10, 24)
AFTER_EXPIRE = (2020, 10, 26)


def invert_philox(key, value):
  """Invert the Philox bijection.

  All arithmetic is done on Python integers with explicit 32-bit masking, so
  the result does not depend on NumPy's scalar promotion rules (which changed
  with NEP 50 / NumPy 2) and no integer-overflow warnings are triggered.

  Args:
    key: Sequence of two integers, the 32-bit key words.
    value: Sequence of four integers, the 32-bit words to invert.

  Returns:
    A `np.uint32` array holding the four inverted 32-bit words.
  """
  mask = 0xffffffff
  key = [int(k) & mask for k in key]
  value = [int(v) & mask for v in value]
  step = (0x9E3779B9, 0xBB67AE85)
  # Walk the ten rounds backwards, from the last round key to the first.
  for n in reversed(range(10)):
    key0 = (key[0] + n * step[0]) & mask
    key1 = (key[1] + n * step[1]) & mask
    # Low words: multiply by the constants and keep the low 32 bits.
    v0 = value[3] * 0x991a7cdb & mask
    v2 = value[1] * 0x6d7cae67 & mask
    # High 32 bits of the (at most 64-bit) products.
    hi0 = v0 * 0xD2511F53 >> 32
    hi1 = v2 * 0xCD9E8D57 >> 32
    v1 = hi1 ^ value[0] ^ key0
    v3 = hi0 ^ value[2] ^ key1
    value = v0, v1, v2, v3
  return np.array(value, dtype=np.uint32)


SEEDS = ((7, 17), (11, 5), (2, 3))
SEED_TYPES = [dtypes.int32, dtypes.int64]


def float_cases(shape_dtypes=(None,)):
  """Yields `(name, stateless_op, stateful_op)` cases for float distributions.

  Each op in a yielded case is a partial that still expects a `seed` keyword
  argument (and accepts further keyword arguments, which are forwarded).

  Args:
    shape_dtypes: Iterable of dtypes used to turn the shape into a constant
      tensor; `None` means the shape is passed as a plain Python tuple.
  """
  cases = (
      # Uniform distribution, with and without range
      ('uniform', stateless.stateless_random_uniform, random_ops.random_uniform,
       {}),
      ('uniform2', stateless.stateless_random_uniform,
       random_ops.random_uniform, dict(minval=2.2, maxval=7.1)),
      # Normal distribution, with and without mean+stddev
      ('normal', stateless.stateless_random_normal, random_ops.random_normal,
       {}),
      ('normal2', stateless.stateless_random_normal, random_ops.random_normal,
       dict(mean=2, stddev=3)),
      # Truncated normal distribution, with and without mean+stddev
      ('trnorm', stateless.stateless_truncated_normal,
       random_ops.truncated_normal, {}),
      ('trnorm2', stateless.stateless_truncated_normal,
       random_ops.truncated_normal, dict(mean=3, stddev=4)),
  )

  # Explicitly passing in params because capturing cell variable from loop is
  # problematic in Python
  def wrap(op, dtype, shape, shape_dtype, seed, **kwargs):
    device_type = get_device().device_type
    # Some dtypes are not supported on some devices
    if ((dtype == dtypes.float16 and device_type in ('XLA_GPU', 'XLA_CPU')) or
        (dtype == dtypes.bfloat16 and device_type == 'GPU')):
      dtype = dtypes.float32
    shape_ = (constant_op.constant(shape, dtype=shape_dtype)
              if shape_dtype is not None else shape)
    return op(seed=seed, shape=shape_, dtype=dtype, **kwargs)

  def _name(a):
    # Use the `name` attribute when there is one (e.g. a dtype); otherwise
    # fall back to the object itself (e.g. None).
    return getattr(a, 'name', a)

  for dtype in dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64:
    for shape_dtype in shape_dtypes:
      for shape in (), (3,), (2, 5):
        for name, stateless_op, stateful_op, kwargs in cases:
          yield (('%s_%s_%s_%s' %
                  (name, _name(dtype), shape, _name(shape_dtype))).replace(
                      ' ', ''),
                 functools.partial(wrap, stateless_op, dtype, shape,
                                   shape_dtype, **kwargs),
                 functools.partial(wrap, stateful_op, dtype, shape, shape_dtype,
                                   **kwargs))


def int_cases(shape_dtypes=(None,), minval_maxval=None):
  """Yields `(name, stateless_op, stateful_op)` cases for integer uniform.

  Args:
    shape_dtypes: Iterable of dtypes used to turn the shape into a constant
      tensor; `None` means the shape is passed as a plain Python tuple.
    minval_maxval: Iterable of `(minval, maxval)` pairs; defaults to
      `((2, 11111),)` when `None`.
  """

  def wrap(op, minval, maxval, shape, shape_dtype, dtype, seed, **kwargs):
    shape_ = (constant_op.constant(shape, dtype=shape_dtype)
              if shape_dtype is not None else shape)
    return op(
        seed=seed, shape=shape_, minval=minval, maxval=maxval, dtype=dtype,
        **kwargs)

  if minval_maxval is None:
    minval_maxval = ((2, 11111),)
  for minval, maxval in minval_maxval:
    for shape_dtype in shape_dtypes:
      for shape in (), (3,), (2, 5):
        for dtype in dtypes.int32, dtypes.int64:
          yield ('uniform_%s_%s' % (minval, maxval),
                 functools.partial(wrap, stateless.stateless_random_uniform,
                                   minval, maxval, shape, shape_dtype, dtype),
                 functools.partial(wrap, random_ops.random_uniform, minval,
                                   maxval, shape, shape_dtype, dtype))


def multinomial_cases():
  """Yields `(name, stateless_op, stateful_op)` cases for multinomial."""
  num_samples = 10

  def wrap(op, logits, logits_dtype, output_dtype, seed):
    return op(seed=seed,
              logits=constant_op.constant(logits, dtype=logits_dtype),
              num_samples=num_samples, output_dtype=output_dtype)

  for logits_dtype in np.float16, np.float32, np.float64:
    for output_dtype in dtypes.int32, dtypes.int64:
      for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
                                                [0.25, 0.75]]):
        yield ('multinomial',
               functools.partial(wrap, stateless.stateless_multinomial, logits,
                                 logits_dtype, output_dtype),
               functools.partial(wrap, random_ops.multinomial, logits,
                                 logits_dtype, output_dtype))


def gamma_cases():
  """Yields `(name, stateless_op, stateful_op)` cases for gamma.

  The stateless op is given the shape `(10,) + alpha.shape` while the stateful
  op is given `(10,)` (presumably because the stateful op appends the shape of
  `alpha` itself -- confirm against the op documentation).
  """

  def wrap(op, alpha, dtype, shape, seed):
    return op(seed=seed, shape=shape,
              alpha=constant_op.constant(alpha, dtype=dtype), dtype=dtype)

  for dtype in np.float16, np.float32, np.float64:
    for alpha in ([[.5, 1., 2.]], [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]]):
      yield ('gamma',
             functools.partial(wrap, stateless.stateless_random_gamma, alpha,
                               dtype, (10,) + tuple(np.shape(alpha))),
             functools.partial(wrap, random_ops.random_gamma, alpha, dtype,
                               (10,)))


def poisson_cases():
  """Yields `(name, stateless_op, stateful_op)` cases for Poisson.

  As in `gamma_cases`, the stateless op is given the shape
  `(10,) + lam.shape` while the stateful op is given `(10,)`.
  """

  def wrap(op, lam, lam_dtype, out_dtype, shape, seed):
    return op(seed=seed, shape=shape,
              lam=constant_op.constant(lam_dtype(lam), dtype=lam_dtype),
              dtype=out_dtype)

  for lam_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
    for out_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
      for lam in ([[5.5, 1., 2.]], [[7.5, 10.5], [3.8, 8.2], [1.25, 9.75]]):
        yield ('poisson',
               functools.partial(wrap, stateless.stateless_random_poisson, lam,
                                 lam_dtype, out_dtype,
                                 (10,) + tuple(np.shape(lam))),
               functools.partial(wrap, random_ops.random_poisson, lam,
                                 lam_dtype, out_dtype, (10,)))


def shuffle_cases():
  """Yields `(name, stateless_op, stateful_op)` cases for shuffle."""
  for dtype in np.int32, np.int64, np.float32, np.float64:
    # [], [0, ...] and [1, ...] are important corner cases
    for shape in ([], [0], [1], [100], [0, 0], [1, 0], [0, 1], [1, 2], [5, 3],
                  [7, 5, 3, 2]):
      value = np.arange(np.prod(shape)).reshape(shape).astype(dtype)
      yield ('shuffle',
             functools.partial(stateless.stateless_shuffle, value),
             functools.partial(random_ops.random_shuffle, value))


@test_util.with_eager_op_as_function
class StatelessOpsTest(test.TestCase, parameterized.TestCase):
  """Tests for the stateless random ops and their seed utilities."""

  def _test_match(self, case, seed):
    """Checks that a stateless op reproduces the matching stateful op."""
    # Stateless ops should be the same as stateful ops on the first call
    # after seed scrambling.
    key = 0x3ec8f720, 0x02461e29
    preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
    # Pack each pair of 32-bit words into one 64-bit seed element.
    preseed = preseed[::2] | preseed[1::2] << 32
    with ops.device(get_device().name):
      _, stateless_op, stateful_op = case
      random_seed.set_random_seed(seed[0])
      stateful = stateful_op(seed=seed[1])
      pure = stateless_op(seed=preseed)
      self.assertAllEqual(stateful, pure)

  def _test_match_stateless_cpu_gpu(self, case, seed):
    """Checks that a stateless op gives close results on CPU and `get_device()`."""
    # Stateless ops should produce the same result on CPUs and GPUs.
    _, stateless_op, _ = case

    with ops.device('CPU'):
      result_cpu = stateless_op(seed=seed)

    with ops.device(get_device().name):
      result_gpu = stateless_op(seed=seed)
      self.assertAllClose(result_cpu, result_gpu)

  def _test_old_and_new_stateless_match(self, case, seed):
    """Tests that the new stateless ops match the old stateless ones."""
    with ops.device(get_device().name):
      _, stateless_op, _ = case
      with compat.forward_compatibility_horizon(*BEFORE_EXPIRE):
        old = stateless_op(seed=seed)
      with compat.forward_compatibility_horizon(*AFTER_EXPIRE):
        new = stateless_op(seed=seed)
      self.assertAllClose(old, new)

  def _test_explicit_alg(self, case, seed):
    """Tests that alg=philox and alg=None are the same (on CPU/GPU)."""
    with ops.device(get_device().name):
      _, stateless_op, _ = case
      implicit_alg = stateless_op(seed=seed)
      # All device types allowed in this test will result in Philox
      explicit_alg = stateless_op(seed=seed, alg='philox')
      self.assertAllClose(implicit_alg, explicit_alg)

  def _test_determinism(self, case, seed_type):
    """Checks that equal seeds give equal values and unequal seeds differ."""
    # Stateless values should be equal iff the seeds are equal (roughly)
    seeds = [(x, y) for x in range(5) for y in range(5)] * 3  # pylint: disable=g-complex-comprehension
    with self.test_session(), ops.device(get_device().name):
      _, stateless_op, _ = case
      if context.executing_eagerly():
        values = [
            (seed, stateless_op(seed=constant_op.constant(seed, seed_type)))
            for seed in seeds]
      else:
        # Have this branch because the above branch is too slow in graph
        # mode
        seed_t = array_ops.placeholder(seed_type, shape=[2])
        pure = stateless_op(seed=seed_t)
        values = [
            (seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds
        ]
      for s0, v0 in values:
        for s1, v1 in values:
          if dtypes.as_dtype(v0.dtype) != dtypes.bfloat16:
            self.assertEqual(s0 == s1, np.all(v0 == v1))
          elif s0 == s1:
            # Skip the s0 != s1 case because v0 and v1 can be either equal or
            # unequal in that case due to bfloat16's low precision
            self.assertAllEqual(v0, v1)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(float_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchFloat(self, case, seed):
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Skip on XLA because XLA kernels do not support int64 '
                    'seeds needed by this test.')
    self._test_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(int_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchInt(self, case, seed):
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Skip on XLA because XLA kernels do not support int64 '
                    'seeds needed by this test.')
    self._test_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(multinomial_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchMultinomial(self, case, seed):
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(gamma_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchGamma(self, case, seed):
    if get_device().device_type == 'GPU':
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking GPU kernel')
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(gamma_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testStatelessGammaCpuGpuMatch(self, case, seed):
    if get_device().device_type != 'GPU':
      # This test compares the numbers produced by the CPU and GPU kernel for
      # stateless_random_gamma.
      self.skipTest('This test requires GPU')
    self._test_match_stateless_cpu_gpu(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(poisson_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchPoisson(self, case, seed):
    if get_device().device_type == 'GPU':
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking GPU kernel')
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension,undefined-variable
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(shuffle_cases()))
  def testMatchShuffle(self, case, seed):
    if get_device().device_type == 'GPU':
      self.skipTest('Lacking GPU kernel')
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(float_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testOldAndNewStatelessMatchFloat(self, case, seed):
    self._test_old_and_new_stateless_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(
          int_cases(minval_maxval=((2, 11111), (None, None)))))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testOldAndNewStatelessMatchInt(self, case, seed):
    self._test_old_and_new_stateless_match(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s' % (case[0], case_id), case)
      for case_id, case in enumerate(float_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testExplicitAlgFloat(self, case):
    seed = (7, 17)
    self._test_explicit_alg(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s' % (case[0], case_id), case)
      for case_id, case in enumerate(
          int_cases(minval_maxval=((2, 11111), (None, None)))))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testExplicitAlgInt(self, case):
    seed = (7, 17)
    self._test_explicit_alg(case, seed)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
      for seed_type in SEED_TYPES
      for case_id, case in enumerate(
          float_cases(shape_dtypes=(dtypes.int32, dtypes.int64))))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismFloat(self, case, seed_type):
    if seed_type == dtypes.int64 and get_device().device_type in ('XLA_GPU',
                                                                  'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest(
          'Skip on XLA because XLA kernels do not support int64 seeds.')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
      for seed_type in SEED_TYPES
      for case_id, case in enumerate(
          int_cases(shape_dtypes=(dtypes.int32, dtypes.int64))))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismInt(self, case, seed_type):
    if seed_type == dtypes.int64 and get_device().device_type in ('XLA_GPU',
                                                                  'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest(
          'Skip on XLA because XLA kernels do not support int64 seeds.')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
      for seed_type in SEED_TYPES
      for case_id, case in enumerate(multinomial_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismMultinomial(self, case, seed_type):
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking XLA kernel')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
      for seed_type in SEED_TYPES
      for case_id, case in enumerate(gamma_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismGamma(self, case, seed_type):
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking XLA kernel')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
      for seed_type in SEED_TYPES
      for case_id, case in enumerate(poisson_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismPoisson(self, case, seed_type):
    if get_device().device_type == 'GPU':
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking GPU kernel')
    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
      # This test was passing before because soft placement silently picked the
      # CPU kernels.
      self.skipTest('Lacking XLA kernel')
    self._test_determinism(case, seed_type)

  @test_util.run_v2_only
  def testGetKeyCounterAlg(self):
    seed = [1, 2]
    key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
        seed)
    self.assertAllEqual(key.shape, [1])
    self.assertAllEqual(counter.shape, [2])
    alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
    self.assertAllEqual(alg.shape, [])

  def assertDTypeEqual(self, a, b):
    """Asserts that `a` and `b` denote the same dtype."""
    self.assertEqual(dtypes.as_dtype(a), dtypes.as_dtype(b))

  def assertNoEqualPair(self, ls):
    """Asserts that no two distinct entries of the list `ls` are all-equal."""
    for i, first in enumerate(ls):
      for second in ls[i + 1:]:
        self.assertFalse(math_ops.reduce_all(first == second))

  @parameterized.parameters(['int32', 'int64'])
  @test_util.run_v2_only
  def testSplit(self, dtype):
    """Test for `split`."""
    seed = constant_op.constant([1, 2], dtype=dtype)
    new_seed = stateless.split(seed, 3)
    self.assertEqual(new_seed.shape, [3, 2])
    self.assertDTypeEqual(new_seed.dtype, dtype)
    self.assertNoEqualPair([seed] + array_ops.unstack(new_seed))

  @parameterized.parameters(['int32', 'int64'])
  @test_util.run_v2_only
  def testFoldIn(self, dtype):
    """Test for `fold_in`."""
    orig_seed = constant_op.constant([1, 2], dtype='int32')
    seed = stateless.fold_in(orig_seed, constant_op.constant(3, dtype=dtype))
    new_seeds = []
    new_seeds.append(seed)
    seed = stateless.fold_in(seed, constant_op.constant(4, dtype=dtype))
    new_seeds.append(seed)
    for s in new_seeds:
      self.assertEqual(s.shape, [2])
      self.assertDTypeEqual(s.dtype, dtype)
    self.assertNoEqualPair([math_ops.cast(orig_seed, dtype)] + new_seeds)

  @test_util.run_v2_only
  def testErrors(self):
    """Tests that proper errors are raised."""
    shape = [2, 3]
    with self.assertRaisesWithPredicateMatch(
        ValueError,
        'minval must be a scalar; got a tensor of shape '):
      @def_function.function
      def f():
        stateless.stateless_random_uniform(
            shape=shape, seed=[1, 2], minval=array_ops.zeros(shape, 'int32'),
            maxval=100, dtype='int32')
      f()
    with self.assertRaisesWithPredicateMatch(
        ValueError,
        'maxval must be a scalar; got a tensor of shape '):
      @def_function.function
      def f2():
        stateless.stateless_random_uniform(
            shape=shape, seed=[1, 2], minval=0,
            maxval=array_ops.ones(shape, 'int32') * 100,
            dtype='int32')
      f2()


if __name__ == '__main__':
  config.set_soft_device_placement(False)
  context.context().enable_xla_devices()
  test.main()