# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for generating random numbers."""

from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.stateless_random_ops import Algorithm
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# A seed for random ops (stateful and stateless) will always be 1024
# bits, all of which will be sent to the C++ code. The actual C++
# implementation of some algorithms may only use a lower part of the bits.

UINT64_HALF_SPAN = 2**63
MAX_INT64 = UINT64_HALF_SPAN - 1
MIN_INT64 = -UINT64_HALF_SPAN
UINT64_SPAN = UINT64_HALF_SPAN * 2
# 'Variable' doesn't support uint32 or uint64 yet (due to reasons explained
# in b/111604096 and cl/171681867), so a signed int is used here. int64 is
# chosen over int32 because `VarHandleOp` doesn't support int32 on GPU.
SEED_TYPE = "int64"
SEED_MIN = MIN_INT64
SEED_MAX = MAX_INT64
SEED_UINT_SPAN = UINT64_SPAN
SEED_TYPE_BITS = 64
SEED_BIT_MASK = 0xFFFFFFFFFFFFFFFF
SEED_SIZE = 16  # in units of SEED_TYPE


STATE_TYPE = SEED_TYPE
ALGORITHM_TYPE = STATE_TYPE
PHILOX_STATE_SIZE = 3
THREEFRY_STATE_SIZE = 2


RNG_ALG_PHILOX = Algorithm.PHILOX.value
RNG_ALG_THREEFRY = Algorithm.THREEFRY.value
DEFAULT_ALGORITHM = RNG_ALG_PHILOX


def non_deterministic_ints(shape, dtype=dtypes.int64):
  """Non-deterministically generates some integers.

  This op may use some OS-provided source of non-determinism (e.g. an RNG),
  so each execution will give different results.

  Args:
    shape: the shape of the result.
    dtype: (optional) the dtype of the result.

  Returns:
    a tensor whose element values are non-deterministically chosen.
  """
  return gen_stateful_random_ops.non_deterministic_ints(
      shape=shape, dtype=dtype)


def _uint_to_int(n):
  if isinstance(n, int) and n > SEED_MAX:
    n = n - SEED_UINT_SPAN
  return n
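

# Example of the chopping performed by `_make_1d_state` below (a sketch of
# the logic, not part of the public API; the values are arbitrary):
#   seed = 5 + (1 << 64), state_size = 2
#   chunk 0: seed & SEED_BIT_MASK == 5          (low 64 bits)
#   chunk 1: (seed >> 64) & SEED_BIT_MASK == 1  (next 64 bits)
# Chunks larger than MAX_INT64 are wrapped to negative int64 values by
# `_uint_to_int` before conversion to a tensor.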
def _make_1d_state(state_size, seed):
  """Makes a 1-D RNG state.

  Args:
    state_size: an integer.
    seed: an integer or 1-D tensor.

  Returns:
    a 1-D tensor of shape [state_size] and dtype STATE_TYPE.
  """
  if isinstance(seed, int):
    # Chop the Python integer (infinite precision) into chunks of SEED_TYPE.
    ls = []
    for _ in range(state_size):
      ls.append(seed & SEED_BIT_MASK)
      seed >>= SEED_TYPE_BITS
    seed = ls
  # to avoid overflow error from ops.convert_to_tensor
  seed = nest.map_structure(_uint_to_int, seed)
  seed = math_ops.cast(seed, STATE_TYPE)
  seed = array_ops.reshape(seed, [-1])
  seed = seed[0:state_size]
  # Pad with zeros on the *left* if too short. Padding on the right would
  # cause a small seed to be used as the "counter" while the "key" is always
  # zero (for counter-based RNG algorithms), because in the current memory
  # layout the counter is stored before the key. In such a situation two RNGs
  # with two different small seeds may generate overlapping outputs.
  seed_size = seed.shape[0]
  if seed_size is None:
    seed_size = array_ops.shape(seed)[0]
  padding_size = math_ops.maximum(state_size - seed_size, 0)
  padding = array_ops.zeros([padding_size], seed.dtype)
  # can't use `pad` because it doesn't support integer dtypes on GPU
  seed = array_ops.concat([padding, seed], axis=0)
  seed.set_shape([state_size])
  return seed


def _get_counter_size(alg):
  if alg == RNG_ALG_PHILOX:
    return 2
  elif alg == RNG_ALG_THREEFRY:
    return 1
  else:
    raise ValueError(
        f"Argument `alg` got unsupported value {alg}. Supported values are "
        f"{RNG_ALG_PHILOX} for the Philox algorithm and {RNG_ALG_THREEFRY} "
        f"for the ThreeFry algorithm.")


def _get_state_size(alg):
  if alg == RNG_ALG_PHILOX:
    return PHILOX_STATE_SIZE
  elif alg == RNG_ALG_THREEFRY:
    return THREEFRY_STATE_SIZE
  else:
    raise ValueError(
        f"Argument `alg` got unsupported value {alg}. Supported values are "
        f"{RNG_ALG_PHILOX} for the Philox algorithm and {RNG_ALG_THREEFRY} "
        f"for the ThreeFry algorithm.")


def _check_state_shape(shape, alg):
  if isinstance(alg, ops.Tensor) and not context.executing_eagerly():
    return
  shape.assert_is_compatible_with([_get_state_size(int(alg))])


def _make_state_from_seed(seed, alg):
  return _make_1d_state(_get_state_size(alg), seed)


@tf_export("random.create_rng_state", "random.experimental.create_rng_state")
def create_rng_state(seed, alg):
  """Creates an RNG state from an integer or a vector.

  Example:

  >>> tf.random.create_rng_state(
  ...     1234, "philox")
  <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1234, 0, 0])>
  >>> tf.random.create_rng_state(
  ...     [12, 34], "threefry")
  <tf.Tensor: shape=(2,), dtype=int64, numpy=array([12, 34])>

  Args:
    seed: an integer or 1-D tensor.
    alg: the RNG algorithm. Can be a string, an `Algorithm` or an integer.

  Returns:
    a 1-D tensor whose size depends on the algorithm.
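
  For example, an `Algorithm` enum value is also accepted in place of the
  string name (mirroring the string form above):

  >>> tf.random.create_rng_state(1234, tf.random.Algorithm.PHILOX)
  <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1234, 0, 0])>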
182 """ 183 alg = stateless_random_ops.convert_alg_to_int(alg) 184 return _make_state_from_seed(seed, alg) 185 186 187def _shape_tensor(shape): 188 """Convert to an int32 or int64 tensor, defaulting to int64 if empty.""" 189 if isinstance(shape, (tuple, list)) and not shape: 190 dtype = dtypes.int64 191 else: 192 dtype = None 193 return ops.convert_to_tensor(shape, dtype=dtype, name="shape") 194 195 196def _convert_to_state_tensor(t): 197 # to avoid out-of-range error from ops.convert_to_tensor 198 t = nest.map_structure(_uint_to_int, t) 199 return math_ops.cast(t, STATE_TYPE) 200 201 202def get_replica_id(): 203 rctx = ds_context.get_replica_context() 204 if rctx is None: 205 return None 206 return rctx.replica_id_in_sync_group 207 208 209@tf_export("random.Generator", "random.experimental.Generator") 210class Generator(autotrackable.AutoTrackable): 211 """Random-number generator. 212 213 Example: 214 215 Creating a generator from a seed: 216 217 >>> g = tf.random.Generator.from_seed(1234) 218 >>> g.normal(shape=(2, 3)) 219 <tf.Tensor: shape=(2, 3), dtype=float32, numpy= 220 array([[ 0.9356609 , 1.0854305 , -0.93788373], 221 [-0.5061547 , 1.3169702 , 0.7137579 ]], dtype=float32)> 222 223 Creating a generator from a non-deterministic state: 224 225 >>> g = tf.random.Generator.from_non_deterministic_state() 226 >>> g.normal(shape=(2, 3)) 227 <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...> 228 229 All the constructors allow explicitly choosing an Random-Number-Generation 230 (RNG) algorithm. Supported algorithms are `"philox"` and `"threefry"`. For 231 example: 232 233 >>> g = tf.random.Generator.from_seed(123, alg="philox") 234 >>> g.normal(shape=(2, 3)) 235 <tf.Tensor: shape=(2, 3), dtype=float32, numpy= 236 array([[ 0.8673864 , -0.29899067, -0.9310337 ], 237 [-1.5828488 , 1.2481191 , -0.6770643 ]], dtype=float32)> 238 239 CPU, GPU and TPU with the same algorithm and seed will generate the same 240 integer random numbers. Float-point results (such as the output of `normal`) 241 may have small numerical discrepancies between different devices. 242 243 This class uses a `tf.Variable` to manage its internal state. Every time 244 random numbers are generated, the state of the generator will change. For 245 example: 246 247 >>> g = tf.random.Generator.from_seed(1234) 248 >>> g.state 249 <tf.Variable ... numpy=array([1234, 0, 0])> 250 >>> g.normal(shape=(2, 3)) 251 <...> 252 >>> g.state 253 <tf.Variable ... numpy=array([2770, 0, 0])> 254 255 The shape of the state is algorithm-specific. 256 257 There is also a global generator: 258 259 >>> g = tf.random.get_global_generator() 260 >>> g.normal(shape=(2, 3)) 261 <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...> 262 263 When creating a generator inside a `tf.distribute.Strategy` scope, each 264 replica will get a different stream of random numbers. 265 266 For example, in this code: 267 268 ``` 269 strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"]) 270 with strat.scope(): 271 g = tf.random.Generator.from_seed(1) 272 def f(): 273 return g.normal([]) 274 results = strat.run(f).values 275 ``` 276 277 `results[0]` and `results[1]` will have different values. 278 279 If the generator is seeded (e.g. created via `Generator.from_seed`), the 280 random numbers will be determined by the seed, even though different replicas 281 get different numbers. One can think of a random number generated on a 282 replica as a hash of the replica ID and a "master" random number that may be 283 common to all replicas. 

  (Note that the random numbers on different replicas are not correlated,
  even though they are deterministically determined by the same seed: no
  matter what statistics one calculates on them, there won't be any
  discernible correlation.)

  Generators can be freely saved and restored using `tf.train.Checkpoint`.
  The checkpoint can be restored in a distribution strategy with a different
  number of replicas than the original strategy. If a replica ID is present
  in both the original and the new distribution strategy, its state will be
  properly restored (i.e. the random-number stream from the restored point
  will be the same as that from the saving point) unless the replicas had
  already diverged in their RNG call traces before saving (e.g. one replica
  had made one RNG call while another had made two RNG calls). There is no
  such guarantee if the generator is saved in a strategy scope and restored
  outside of any strategy scope, or vice versa.

  When a generator is created within the scope of
  `tf.distribute.experimental.ParameterServerStrategy`, the workers
  will share the generator's state (placed on one of the parameter
  servers). In this way the workers will still get different
  random-number streams, as stated above. (This is similar to replicas
  in a `tf.distribute.MirroredStrategy` sequentially accessing a
  generator created outside the strategy.) Each RNG call on a worker
  will incur a round-trip to a parameter server, which may have
  performance impacts. When creating a
  `tf.distribute.experimental.ParameterServerStrategy`, please make
  sure that the `variable_partitioner` argument won't shard small
  variables of shape `[2]` or `[3]` (because generator states must not
  be sharded). Ways to avoid sharding small variables include setting
  `variable_partitioner` to `None` or to
  `tf.distribute.experimental.partitioners.MinSizePartitioner` with a
  large enough `min_shard_bytes` (see
  `tf.distribute.experimental.ParameterServerStrategy`'s documentation
  for more details).
  """

  @classmethod
  def from_state(cls, state, alg):
    """Creates a generator from a state.

    See `__init__` for description of `state` and `alg`.

    Args:
      state: the new state.
      alg: the RNG algorithm.

    Returns:
      The new generator.
    """
    return cls(alg=alg, state=state)

  @classmethod
  def from_seed(cls, seed, alg=None):
    """Creates a generator from a seed.

    A seed is a 1024-bit unsigned integer represented either as a Python
    integer or a vector of integers. Seeds shorter than 1024 bits will be
    padded. The padding, the internal structure of a seed and the way a seed
    is converted to a state are all opaque (unspecified). The only specified
    semantics of seeds is that two different seeds are likely to produce two
    independent generators (though this is not guaranteed).

    Args:
      seed: the seed for the RNG.
      alg: (optional) the RNG algorithm. If None, it will be auto-selected.
        See `__init__` for its possible values.

    Returns:
      The new generator.
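
    For example, a vector seed shorter than the state size is padded with
    zeros on the left (a usage sketch; the exact state layout is an
    implementation detail):

    >>> g = tf.random.Generator.from_seed([12, 34])
    >>> g.state
    <tf.Variable ... numpy=array([ 0, 12, 34])>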
354 """ 355 if alg is None: 356 # TODO(b/170668986): more sophisticated algorithm selection 357 alg = DEFAULT_ALGORITHM 358 alg = stateless_random_ops.convert_alg_to_int(alg) 359 state = create_rng_state(seed, alg) 360 return cls(state=state, alg=alg) 361 362 @classmethod 363 def from_non_deterministic_state(cls, alg=None): 364 """Creates a generator by non-deterministically initializing its state. 365 366 The source of the non-determinism will be platform- and time-dependent. 367 368 Args: 369 alg: (optional) the RNG algorithm. If None, it will be auto-selected. See 370 `__init__` for its possible values. 371 372 Returns: 373 The new generator. 374 """ 375 if config.is_op_determinism_enabled(): 376 raise RuntimeError('"from_non_deterministic_state" cannot be called when ' # pylint: disable=g-doc-exception 377 "determinism is enabled.") 378 if alg is None: 379 # TODO(b/170668986): more sophisticated algorithm selection 380 alg = DEFAULT_ALGORITHM 381 alg = stateless_random_ops.convert_alg_to_int(alg) 382 state = non_deterministic_ints(shape=[_get_state_size(alg)], 383 dtype=SEED_TYPE) 384 return cls(state=state, alg=alg) 385 386 @classmethod 387 def from_key_counter(cls, key, counter, alg): 388 """Creates a generator from a key and a counter. 389 390 This constructor only applies if the algorithm is a counter-based algorithm. 391 See method `key` for the meaning of "key" and "counter". 392 393 Args: 394 key: the key for the RNG, a scalar of type STATE_TYPE. 395 counter: a vector of dtype STATE_TYPE representing the initial counter for 396 the RNG, whose length is algorithm-specific., 397 alg: the RNG algorithm. If None, it will be auto-selected. See 398 `__init__` for its possible values. 399 400 Returns: 401 The new generator. 402 """ 403 counter = _convert_to_state_tensor(counter) 404 key = _convert_to_state_tensor(key) 405 alg = stateless_random_ops.convert_alg_to_int(alg) 406 counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1]) 407 key.shape.assert_is_compatible_with([]) 408 key = array_ops.reshape(key, [1]) 409 state = array_ops.concat([counter, key], 0) 410 return cls(state=state, alg=alg) 411 412 def __init__(self, copy_from=None, state=None, alg=None): 413 """Creates a generator. 414 415 The new generator will be initialized by one of the following ways, with 416 decreasing precedence: 417 (1) If `copy_from` is not None, the new generator is initialized by copying 418 information from another generator. 419 (2) If `state` and `alg` are not None (they must be set together), the new 420 generator is initialized by a state. 421 422 Args: 423 copy_from: a generator to be copied from. 424 state: a vector of dtype STATE_TYPE representing the initial state of the 425 RNG, whose length and semantics are algorithm-specific. If it's a 426 variable, the generator will reuse it instead of creating a new 427 variable. 428 alg: the RNG algorithm. Possible values are 429 `tf.random.Algorithm.PHILOX` for the Philox algorithm and 430 `tf.random.Algorithm.THREEFRY` for the ThreeFry algorithm 431 (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3' 432 [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]). 433 The string names `"philox"` and `"threefry"` can also be used. 434 Note `PHILOX` guarantees the same numbers are produced (given 435 the same random state) across all architectures (CPU, GPU, XLA etc). 436 """ 437 # TODO(b/175072242): Remove distribution-strategy dependencies in this file. 
    if ds_context.has_strategy():
      self._distribution_strategy = ds_context.get_strategy()
    else:
      self._distribution_strategy = None
    if copy_from is not None:
      # All other arguments should be None
      assert (alg or state) is None
      self._state_var = self._create_variable(copy_from.state,
                                              dtype=STATE_TYPE,
                                              trainable=False)
      self._alg = copy_from.algorithm
    else:
      assert alg is not None and state is not None
      alg = stateless_random_ops.convert_alg_to_int(alg)
      if isinstance(state, variables.Variable):
        _check_state_shape(state.shape, alg)
        self._state_var = state
      else:
        state = _convert_to_state_tensor(state)
        _check_state_shape(state.shape, alg)
        self._state_var = self._create_variable(state, dtype=STATE_TYPE,
                                                trainable=False)
      self._alg = alg

  def _create_variable(self, *args, **kwargs):
    """Creates a variable.

    Args:
      *args: positional arguments passed along to `variables.Variable`.
      **kwargs: keyword arguments passed along to `variables.Variable`.

    Returns:
      The created variable.
    """
    with ops.name_scope("random_generator"):
      # Make sure we don't change this name since Keras was using this name
      # to filter out the state variable.
      kwargs["name"] = "StateVar"
      v = variables.Variable(*args, **kwargs)
    if isinstance(v, sharded_variable.ShardedVariable):
      # RNG state is an atomic entity representing a 128-bit or
      # 192-bit value, so it mustn't be sharded.
      raise ValueError(
          "tf.random.Generator state is sharded, which is not allowed. When "
          "creating a tf.distribute.experimental.ParameterServerStrategy, "
          "please make sure that the `variable_partitioner` "
          "argument won't shard a "
          "small variable of shape [2] or [3]. Ways to avoid sharding small "
          "variables include setting `variable_partitioner` to None or to "
          "tf.distribute.experimental.partitioners.MinSizePartitioner with a "
          "large enough `min_shard_bytes`.")
    return v

  def reset(self, state):
    """Resets the generator by a new state.

    See `__init__` for the meaning of "state".

    Args:
      state: the new state.
    """
    state = _convert_to_state_tensor(state)
    state.shape.assert_is_compatible_with([_get_state_size(self.algorithm)])
    self._state_var.assign(state)

  def reset_from_seed(self, seed):
    """Resets the generator by a new seed.

    See `from_seed` for the meaning of "seed".

    Args:
      seed: the new seed.
    """
    state = create_rng_state(seed, self.algorithm)
    self._state_var.assign(state)

  def reset_from_key_counter(self, key, counter):
    """Resets the generator by a new key-counter pair.

    See `from_key_counter` for the meaning of "key" and "counter".

    Args:
      key: the new key.
      counter: the new counter.
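
    For example (a usage sketch; the values are arbitrary):

    >>> g = tf.random.Generator.from_seed(1234)
    >>> g.reset_from_key_counter(key=17, counter=[0, 0])
    >>> g.state
    <tf.Variable ... numpy=array([ 0,  0, 17])>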
521 """ 522 counter = _convert_to_state_tensor(counter) 523 key = _convert_to_state_tensor(key) 524 counter.shape.assert_is_compatible_with( 525 [_get_state_size(self.algorithm) - 1]) 526 key.shape.assert_is_compatible_with([]) 527 key = array_ops.reshape(key, [1]) 528 state = array_ops.concat([counter, key], 0) 529 self._state_var.assign(state) 530 531 @property 532 def state(self): 533 """The internal state of the RNG.""" 534 return self._state_var 535 536 @property 537 def algorithm(self): 538 """The RNG algorithm id (a Python integer or scalar integer Tensor).""" 539 return self._alg 540 541 def _standard_normal(self, shape, dtype): 542 key, counter = self._prepare_key_counter(shape) 543 return gen_stateless_random_ops_v2.stateless_random_normal_v2( 544 shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm) 545 546 @property 547 def key(self): 548 """The 'key' part of the state of a counter-based RNG. 549 550 For a counter-base RNG algorithm such as Philox and ThreeFry (as 551 described in paper 'Parallel Random Numbers: As Easy as 1, 2, 3' 552 [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]), 553 the RNG state consists of two parts: counter and key. The output is 554 generated via the formula: output=hash(key, counter), i.e. a hashing of 555 the counter parametrized by the key. Two RNGs with two different keys can 556 be thought as generating two independent random-number streams (a stream 557 is formed by increasing the counter). 558 559 Returns: 560 A scalar which is the 'key' part of the state, if the RNG algorithm is 561 counter-based; otherwise it raises a ValueError. 562 """ 563 alg = self.algorithm 564 if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY: 565 return self._state_var[-1] 566 else: 567 raise ValueError( 568 f"This generator uses an unsupported algorithm {alg}. Supported " 569 f"values are {RNG_ALG_PHILOX} for the Philox algorithm and " 570 f"{RNG_ALG_THREEFRY} for the ThreeFry algorithm.") 571 572 def _skip_single_var(self, var, delta): 573 resource_variable_ops.variable_accessed(var) 574 # TODO(wangpeng): Cache the cast algorithm instead of casting everytime. 575 return gen_stateful_random_ops.rng_read_and_skip( 576 var.handle, 577 alg=math_ops.cast(self.algorithm, dtypes.int32), 578 delta=math_ops.cast(delta, dtypes.uint64)) 579 580 def skip(self, delta): 581 """Advance the counter of a counter-based RNG. 582 583 Args: 584 delta: the amount of advancement. The state of the RNG after 585 `skip(n)` will be the same as that after `normal([n])` 586 (or any other distribution). The actual increment added to the 587 counter is an unspecified implementation detail. 588 589 Returns: 590 A `Tensor` of type `int64`. 591 """ 592 593 def update_fn(v): 594 return self._skip_single_var(v, delta) 595 # TODO(b/170515001): Always call strategy.extended.update after calling it 596 # from both replica context and cross-replica context is supported. 597 if values_util.is_saving_non_distributed(): 598 # Assumes replica context with replica_id=0, since we only save the first 599 # replica. 600 return update_fn(self.state) 601 if self._distribution_strategy is not None: 602 with ds_context.enter_or_assert_strategy(self._distribution_strategy): 603 if ds_context.in_cross_replica_context(): 604 # Code that operates on all replicas of a variable cannot be saved 605 # without retracing. 

  def _skip_single_var(self, var, delta):
    resource_variable_ops.variable_accessed(var)
    # TODO(wangpeng): Cache the cast algorithm instead of casting every time.
    return gen_stateful_random_ops.rng_read_and_skip(
        var.handle,
        alg=math_ops.cast(self.algorithm, dtypes.int32),
        delta=math_ops.cast(delta, dtypes.uint64))

  def skip(self, delta):
    """Advances the counter of a counter-based RNG.

    Args:
      delta: the amount of advancement. The state of the RNG after
        `skip(n)` will be the same as that after `normal([n])`
        (or any other distribution). The actual increment added to the
        counter is an unspecified implementation detail.

    Returns:
      A `Tensor` of type `int64`.
    """

    def update_fn(v):
      return self._skip_single_var(v, delta)
    # TODO(b/170515001): Always call strategy.extended.update once calling
    # it from both replica context and cross-replica context is supported.
    if values_util.is_saving_non_distributed():
      # Assumes replica context with replica_id=0, since we only save the
      # first replica.
      return update_fn(self.state)
    if self._distribution_strategy is not None:
      with ds_context.enter_or_assert_strategy(self._distribution_strategy):
        if ds_context.in_cross_replica_context():
          # Code that operates on all replicas of a variable cannot be saved
          # without retracing.
          values_util.mark_as_unsaveable()
        if (ds_context.in_cross_replica_context() or
            "CentralStorage" in type(self._distribution_strategy).__name__):
          # In cross-replica context we need to use strategy.extended.update.
          # In CentralStorageStrategy we also need to use
          # strategy.extended.update (even for replica context),
          # because variable updates here must be within merge_call.
          return ds_context.get_strategy().extended.update(
              self.state, update_fn)
    return update_fn(self.state)

  def _preprocess_key(self, key):
    if self._distribution_strategy is None:
      return key
    with ds_context.enter_or_assert_strategy(self._distribution_strategy):
      replica_id = get_replica_id()
      if replica_id is not None:
        replica_id = array_ops.stack([replica_id, 0], axis=0)
        replica_id = math_ops.cast(replica_id, dtypes.uint64)
        # Conceptually: key = hash(key, replica_id)
        key = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
            shape=[1], key=key, counter=replica_id, dtype=dtypes.uint64,
            alg=self.algorithm)
      return key

  def _prepare_key_counter(self, shape):
    delta = math_ops.reduce_prod(shape)
    counter_key = self.skip(delta)
    counter_size = _get_counter_size(self.algorithm)
    counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
    key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
                            dtypes.uint64)
    key = self._preprocess_key(key)
    return key, counter

  # The following functions return a tensor and as a side effect update
  # self._state_var.
  def normal(self, shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
             name=None):
    """Outputs random values from a normal distribution.

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
        normal distribution.
      stddev: A 0-D Tensor or Python value of type `dtype`. The standard
        deviation of the normal distribution.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random normal values.
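
    For example (a usage sketch; the exact values depend on the generator
    state):

    >>> g = tf.random.Generator.from_seed(1234)
    >>> g.normal(shape=[2], mean=1.0, stddev=2.0)
    <tf.Tensor: shape=(2,), dtype=float32, numpy=...>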
    """
    with ops.name_scope(name, "stateful_normal", [shape, mean, stddev]) as name:
      shape = _shape_tensor(shape)
      mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
      stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
      rnd = self._standard_normal(shape, dtype=dtype)
      return math_ops.add(rnd * stddev, mean, name=name)

  def _truncated_normal(self, shape, dtype):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
        shape=shape, key=key, counter=counter, dtype=dtype,
        alg=self.algorithm)

  def truncated_normal(self, shape,
                       mean=0.0,
                       stddev=1.0,
                       dtype=dtypes.float32,
                       name=None):
    """Outputs random values from a truncated normal distribution.

    The generated values follow a normal distribution with specified mean
    and standard deviation, except that values whose magnitude is more than
    2 standard deviations from the mean are dropped and re-picked.

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
        truncated normal distribution.
      stddev: A 0-D Tensor or Python value of type `dtype`. The standard
        deviation of the normal distribution, before truncation.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random truncated normal
      values.
    """
    with ops.name_scope(
        name, "truncated_normal", [shape, mean, stddev]) as name:
      shape_tensor = _shape_tensor(shape)
      mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
      stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype,
                                            name="stddev")
      rnd = self._truncated_normal(shape_tensor, dtype=dtype)
      mul = rnd * stddev_tensor
      return math_ops.add(mul, mean_tensor, name=name)

  def _uniform(self, shape, dtype):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_random_uniform_v2(
        shape=shape, key=key, counter=counter, dtype=dtype,
        alg=self.algorithm)

  def _uniform_full_int(self, shape, dtype, name=None):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
        shape=shape,
        key=key,
        counter=counter,
        dtype=dtype,
        alg=self.algorithm,
        name=name)

  def uniform(self, shape, minval=0, maxval=None,
              dtype=dtypes.float32, name=None):
    """Outputs random values from a uniform distribution.

    The generated values follow a uniform distribution in the range
    `[minval, maxval)`. The lower bound `minval` is included in the range,
    while the upper bound `maxval` is excluded. (For floating-point numbers,
    especially low-precision types like bfloat16, the result may
    occasionally include `maxval` because of rounding.)

    For floats, the default range is `[0, 1)`. For ints, at least `maxval`
    must be specified explicitly.

    In the integer case, the random integers are slightly biased unless
    `maxval - minval` is an exact power of two. The bias is small for values
    of `maxval - minval` significantly smaller than the range of the output
    (either `2**32` or `2**64`).

    For full-range random integers, pass `minval=None` and `maxval=None`
    with an integer `dtype` (for integer dtypes, `minval` and `maxval` must
    be both `None` or both not `None`).

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      minval: A Tensor or Python value of type `dtype`, broadcastable with
        `shape` (for integer types, broadcasting is not supported, so it
        needs to be a scalar). The lower bound (included) on the range of
        random values to generate. Pass `None` for full-range integers.
        Defaults to 0.
      maxval: A Tensor or Python value of type `dtype`, broadcastable with
        `shape` (for integer types, broadcasting is not supported, so it
        needs to be a scalar). The upper bound (excluded) on the range of
        random values to generate. Pass `None` for full-range integers.
        Defaults to 1 if `dtype` is floating point.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random uniform values.

    Raises:
      ValueError: If `dtype` is integral and `maxval` is not specified.
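
    For example (a usage sketch; the exact values depend on the generator
    state):

    >>> g = tf.random.Generator.from_seed(1234)
    >>> g.uniform(shape=[2], minval=0, maxval=10, dtype=tf.int32)
    <tf.Tensor: shape=(2,), dtype=int32, numpy=...>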
763 """ 764 dtype = dtypes.as_dtype(dtype) 765 if dtype.is_integer: 766 if (minval is None) != (maxval is None): 767 raise ValueError("For integer dtype {}, minval and maxval must be both " 768 "`None` or both non-`None`; got minval={} and " 769 "maxval={}".format(dtype, minval, maxval)) 770 elif maxval is None: 771 maxval = 1 772 with ops.name_scope(name, "stateful_uniform", 773 [shape, minval, maxval]) as name: 774 shape = _shape_tensor(shape) 775 if dtype.is_integer and minval is None: 776 return self._uniform_full_int(shape=shape, dtype=dtype, name=name) 777 minval = ops.convert_to_tensor(minval, dtype=dtype, name="min") 778 maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max") 779 if dtype.is_integer: 780 key, counter = self._prepare_key_counter(shape) 781 return gen_stateless_random_ops_v2.stateless_random_uniform_int_v2( 782 shape=shape, 783 key=key, 784 counter=counter, 785 minval=minval, 786 maxval=maxval, 787 alg=self.algorithm, 788 name=name) 789 else: 790 rnd = self._uniform(shape=shape, dtype=dtype) 791 return math_ops.add(rnd * (maxval - minval), minval, name=name) 792 793 def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None): 794 """Uniform distribution on an integer type's entire range. 795 796 This method is the same as setting `minval` and `maxval` to `None` in the 797 `uniform` method. 798 799 Args: 800 shape: the shape of the output. 801 dtype: (optional) the integer type, default to uint64. 802 name: (optional) the name of the node. 803 804 Returns: 805 A tensor of random numbers of the required shape. 806 """ 807 dtype = dtypes.as_dtype(dtype) 808 with ops.name_scope(name, "stateful_uniform_full_int", 809 [shape]) as name: 810 shape = _shape_tensor(shape) 811 return self._uniform_full_int(shape=shape, dtype=dtype, name=name) 812 813 def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None): 814 """Outputs random values from a binomial distribution. 815 816 The generated values follow a binomial distribution with specified count and 817 probability of success parameters. 818 819 Example: 820 821 ```python 822 counts = [10., 20.] 823 # Probability of success. 824 probs = [0.8] 825 826 rng = tf.random.Generator.from_seed(seed=234) 827 binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs) 828 829 830 counts = ... # Shape [3, 1, 2] 831 probs = ... # Shape [1, 4, 2] 832 shape = [3, 4, 3, 4, 2] 833 rng = tf.random.Generator.from_seed(seed=1717) 834 # Sample shape will be [3, 4, 3, 4, 2] 835 binomial_samples = rng.binomial(shape=shape, counts=counts, probs=probs) 836 ``` 837 838 839 Args: 840 shape: A 1-D integer Tensor or Python array. The shape of the output 841 tensor. 842 counts: Tensor. The counts of the binomial distribution. Must be 843 broadcastable with `probs`, and broadcastable with the rightmost 844 dimensions of `shape`. 845 probs: Tensor. The probability of success for the 846 binomial distribution. Must be broadcastable with `counts` and 847 broadcastable with the rightmost dimensions of `shape`. 848 dtype: The type of the output. Default: tf.int32 849 name: A name for the operation (optional). 850 851 Returns: 852 samples: A Tensor of the specified shape filled with random binomial 853 values. For each i, each samples[i, ...] is an independent draw from 854 the binomial distribution on counts[i] trials with probability of 855 success probs[i]. 
856 """ 857 dtype = dtypes.as_dtype(dtype) 858 with ops.name_scope(name, "binomial", [shape, counts, probs]) as name: 859 counts = ops.convert_to_tensor(counts, name="counts") 860 probs = ops.convert_to_tensor(probs, name="probs") 861 shape_tensor = _shape_tensor(shape) 862 return gen_stateful_random_ops.stateful_random_binomial( 863 self.state.handle, 864 self.algorithm, 865 shape=shape_tensor, 866 counts=counts, 867 probs=probs, 868 dtype=dtype, 869 name=name) 870 871 # TODO(wangpeng): implement other distributions 872 873 def _make_int64_keys(self, shape=()): 874 # New independent keys are generated via 875 # `new_key[i] = hash(old_key, counter+i)`, which is exactly what 876 # `uniform_full_int(dtype=int64)` does for PhiloxRandom_64_128_128 and 877 # ThreeFry_64_64_64. 878 return self.uniform_full_int(shape=shape, dtype=dtypes.int64) 879 880 def make_seeds(self, count=1): 881 """Generates seeds for stateless random ops. 882 883 For example: 884 885 ```python 886 seeds = get_global_generator().make_seeds(count=10) 887 for i in range(10): 888 seed = seeds[:, i] 889 numbers = stateless_random_normal(shape=[2, 3], seed=seed) 890 ... 891 ``` 892 893 Args: 894 count: the number of seed pairs (note that stateless random ops need a 895 pair of seeds to invoke). 896 897 Returns: 898 A tensor of shape [2, count] and dtype int64. 899 """ 900 alg = self.algorithm 901 if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY: 902 keys = self._make_int64_keys(shape=[count]) 903 # The two seeds for stateless random ops don't have individual semantics 904 # and are scrambled together, so setting one to zero is fine. 905 zeros = array_ops.zeros_like(keys) 906 return array_ops.stack([keys, zeros]) 907 else: 908 raise ValueError( 909 f"This generator uses an unsupported algorithm {alg}. Supported " 910 f"values are {RNG_ALG_PHILOX} for the Philox algorithm and " 911 f"{RNG_ALG_THREEFRY} for the ThreeFry algorithm.") 912 913 def split(self, count=1): 914 """Returns a list of independent `Generator` objects. 915 916 Two generators are independent of each other in the sense that the 917 random-number streams they generate don't have statistically detectable 918 correlations. The new generators are also independent of the old one. 919 The old generator's state will be changed (like other random-number 920 generating methods), so two calls of `split` will return different 921 new generators. 922 923 For example: 924 925 ```python 926 gens = get_global_generator().split(count=10) 927 for gen in gens: 928 numbers = gen.normal(shape=[2, 3]) 929 # ... 930 gens2 = get_global_generator().split(count=10) 931 # gens2 will be different from gens 932 ``` 933 934 The new generators will be put on the current device (possible different 935 from the old generator's), for example: 936 937 ```python 938 with tf.device("/device:CPU:0"): 939 gen = Generator(seed=1234) # gen is on CPU 940 with tf.device("/device:GPU:0"): 941 gens = gen.split(count=10) # gens are on GPU 942 ``` 943 944 Args: 945 count: the number of generators to return. 946 947 Returns: 948 A list (length `count`) of `Generator` objects independent of each other. 949 The new generators have the same RNG algorithm as the old one. 950 """ 951 def _key_to_state(alg, key): 952 # Padding with zeros on the left. The zeros will be the counter. 
      return [0] * (_get_state_size(alg) - 1) + [key]

    alg = self.algorithm
    if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY:
      keys = self._make_int64_keys(shape=[count])
      return [Generator(state=_key_to_state(alg, key), alg=alg)
              for key in array_ops.unstack(keys, num=count)]
    else:
      raise ValueError(
          f"This generator uses an unsupported algorithm {alg}. Supported "
          f"values are {RNG_ALG_PHILOX} for the Philox algorithm and "
          f"{RNG_ALG_THREEFRY} for the ThreeFry algorithm.")


# It's not safe to create TF ops before `init_google` is called, so this is
# initialized to None and gets a value the first time `get_global_generator`
# is called.
global_generator = None


@tf_export("random.get_global_generator",
           "random.experimental.get_global_generator")
def get_global_generator():
  """Retrieves the global generator.

  This function will create the global generator the first time it is
  called, and the generator will be placed on the default device at that
  time, so one needs to be careful when this function is first called. Using
  a generator placed on a less-ideal device will incur a performance
  penalty.

  Returns:
    The global `tf.random.Generator` object.
  """
  global global_generator
  if global_generator is None:
    if config.is_op_determinism_enabled():
      raise RuntimeError('"get_global_generator" cannot be called if '  # pylint: disable=g-doc-exception
                         "determinism is enabled, unless "
                         '"set_global_generator" has already been called. '
                         'Please call "set_global_generator" first.')
    with ops.init_scope():
      global_generator = Generator.from_non_deterministic_state()
  return global_generator


@tf_export("random.set_global_generator",
           "random.experimental.set_global_generator")
def set_global_generator(generator):
  """Replaces the global generator with another `Generator` object.

  This function replaces the global generator with the provided `generator`
  object.
  A random number generator utilizes a `tf.Variable` object to store its
  state. Be aware of the following caveats about how `set_global_generator`
  interacts with `tf.function`:

  - tf.function puts restrictions on Variable creation, thus one cannot
    freely create a new random generator instance inside `tf.function`. To
    call `set_global_generator` inside `tf.function`, the generator instance
    must have already been created eagerly.
  - tf.function captures the Variable during trace-compilation, thus a
    compiled tf.function will not be affected by `set_global_generator`, as
    demonstrated by
    random_test.py/RandomTest.testResetGlobalGeneratorBadWithDefun.

  For most use cases, avoid calling `set_global_generator` after program
  initialization; prefer to reset the state of the existing global generator
  instead, for example:

  >>> rng = tf.random.get_global_generator()
  >>> rng.reset_from_seed(30)


  Args:
    generator: the new `Generator` object.
  """
  global global_generator
  global_generator = generator