# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for generating random numbers."""

from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.stateless_random_ops import Algorithm
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# A seed for random ops (stateful and stateless) will always be 1024
# bits, all of which will be sent to the C++ code. The actual C++
# implementation of some algorithms may only use a lower part of the bits.

UINT64_HALF_SPAN = 2**63
MAX_INT64 = UINT64_HALF_SPAN - 1
MIN_INT64 = -UINT64_HALF_SPAN
UINT64_SPAN = UINT64_HALF_SPAN * 2
# 'Variable' doesn't support uint32 or uint64 yet (due to reasons explained in
# b/111604096 and cl/171681867), so I use signed int here. I choose int64
# instead of int32 here because `VarHandleOp` doesn't support int32 on GPU.
SEED_TYPE = "int64"
SEED_MIN = MIN_INT64
SEED_MAX = MAX_INT64
SEED_UINT_SPAN = UINT64_SPAN
SEED_TYPE_BITS = 64
SEED_BIT_MASK = 0xFFFFFFFFFFFFFFFF
SEED_SIZE = 16  # in units of SEED_TYPE
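# Note: SEED_SIZE * SEED_TYPE_BITS = 16 * 64 = 1024, matching the 1024-bit
# seed described in the comment above.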


STATE_TYPE = SEED_TYPE
ALGORITHM_TYPE = STATE_TYPE
PHILOX_STATE_SIZE = 3
THREEFRY_STATE_SIZE = 2
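# The state layout for both algorithms is [counter..., key]: Philox uses a
# 2-word counter and a 1-word key, ThreeFry a 1-word counter and a 1-word key
# (see `_get_counter_size` and `_get_state_size` below).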


RNG_ALG_PHILOX = Algorithm.PHILOX.value
RNG_ALG_THREEFRY = Algorithm.THREEFRY.value
DEFAULT_ALGORITHM = RNG_ALG_PHILOX


def non_deterministic_ints(shape, dtype=dtypes.int64):
  """Non-deterministically generates some integers.

  This op may use some OS-provided source of non-determinism (e.g. an RNG), so
  each execution will give different results.

  Args:
    shape: the shape of the result.
    dtype: (optional) the dtype of the result.

  Returns:
    a tensor whose element values are non-deterministically chosen.
  """
  return gen_stateful_random_ops.non_deterministic_ints(
      shape=shape, dtype=dtype)


def _uint_to_int(n):
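  # Reinterprets an out-of-range unsigned value as its two's-complement int64
  # equivalent, e.g. _uint_to_int(2**64 - 1) == -1, so that callers can pass
  # the result to `ops.convert_to_tensor` without overflow.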
  if isinstance(n, int) and n > SEED_MAX:
    n = n - SEED_UINT_SPAN
  return n


def _make_1d_state(state_size, seed):
  """Makes a 1-D RNG state.

  Args:
    state_size: an integer.
    seed: an integer or 1-D tensor.

  Returns:
    a 1-D tensor of shape [state_size] and dtype STATE_TYPE.
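
  For illustration (values are arbitrary): an integer seed is chopped into
  64-bit chunks, least-significant chunk first, and a too-short vector seed
  is left-padded with zeros, e.g.

    _make_1d_state(3, 1 + (1 << 64))  # -> [1, 1, 0]
    _make_1d_state(3, [12, 34])       # -> [0, 12, 34]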
100  """
101  if isinstance(seed, int):
102    # chop the Python integer (infinite precision) into chunks of SEED_TYPE
103    ls = []
104    for _ in range(state_size):
105      ls.append(seed & SEED_BIT_MASK)
106      seed >>= SEED_TYPE_BITS
107    seed = ls
108  # to avoid overflow error from ops.convert_to_tensor
109  seed = nest.map_structure(_uint_to_int, seed)
110  seed = math_ops.cast(seed, STATE_TYPE)
111  seed = array_ops.reshape(seed, [-1])
112  seed = seed[0:state_size]
113  # Padding with zeros on the *left* if too short. Padding on the right would
114  # cause a small seed to be used as the "counter" while the "key" is always
115  # zero (for counter-based RNG algorithms), because in the current memory
116  # layout counter is stored before key. In such a situation two RNGs with
117  # two different small seeds may generate overlapping outputs.
118  seed_size = seed.shape[0]
119  if seed_size is None:
120    seed_size = array_ops.shape(seed)[0]
121  padding_size = math_ops.maximum(state_size - seed_size, 0)
122  padding = array_ops.zeros([padding_size], seed.dtype)
123  # can't use `pad` because it doesn't support integer dtypes on GPU
124  seed = array_ops.concat([padding, seed], axis=0)
125  seed.set_shape([state_size])
126  return seed
127
128
129def _get_counter_size(alg):
130  if alg == RNG_ALG_PHILOX:
131    return 2
132  elif alg == RNG_ALG_THREEFRY:
133    return 1
134  else:
135    raise ValueError(
136        f"Argument `alg` got unsupported value {alg}. Supported values are "
137        f"{RNG_ALG_PHILOX} for the Philox algorithm and {RNG_ALG_THREEFRY} for "
138        f"the ThreeFry algorithm.")
139
140
141def _get_state_size(alg):
142  if alg == RNG_ALG_PHILOX:
143    return PHILOX_STATE_SIZE
144  elif alg == RNG_ALG_THREEFRY:
145    return THREEFRY_STATE_SIZE
146  else:
147    raise ValueError(
148        f"Argument `alg` got unsupported value {alg}. Supported values are "
149        f"{RNG_ALG_PHILOX} for the Philox algorithm and {RNG_ALG_THREEFRY} for "
150        f"the ThreeFry algorithm.")
151
152
153def _check_state_shape(shape, alg):
154  if isinstance(alg, ops.Tensor) and not context.executing_eagerly():
155    return
156  shape.assert_is_compatible_with([_get_state_size(int(alg))])
157
158
159def _make_state_from_seed(seed, alg):
160  return _make_1d_state(_get_state_size(alg), seed)
161
162
163@tf_export("random.create_rng_state", "random.experimental.create_rng_state")
164def create_rng_state(seed, alg):
165  """Creates a RNG state from an integer or a vector.

  Example:

  >>> tf.random.create_rng_state(
  ...     1234, "philox")
  <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1234,    0,    0])>
  >>> tf.random.create_rng_state(
  ...     [12, 34], "threefry")
  <tf.Tensor: shape=(2,), dtype=int64, numpy=array([12, 34])>

  Args:
    seed: an integer or 1-D numpy array.
    alg: the RNG algorithm. Can be a string, an `Algorithm` or an integer.

  Returns:
    a 1-D tensor whose size depends on the algorithm.
182  """
183  alg = stateless_random_ops.convert_alg_to_int(alg)
184  return _make_state_from_seed(seed, alg)
185
186
187def _shape_tensor(shape):
188  """Convert to an int32 or int64 tensor, defaulting to int64 if empty."""
189  if isinstance(shape, (tuple, list)) and not shape:
190    dtype = dtypes.int64
191  else:
192    dtype = None
193  return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
194
195
196def _convert_to_state_tensor(t):
197  # to avoid out-of-range error from ops.convert_to_tensor
198  t = nest.map_structure(_uint_to_int, t)
199  return math_ops.cast(t, STATE_TYPE)
200
201
202def get_replica_id():
203  rctx = ds_context.get_replica_context()
204  if rctx is None:
205    return None
206  return rctx.replica_id_in_sync_group
207
208
209@tf_export("random.Generator", "random.experimental.Generator")
210class Generator(autotrackable.AutoTrackable):
211  """Random-number generator.
212
213  Example:
214
215  Creating a generator from a seed:
216
217  >>> g = tf.random.Generator.from_seed(1234)
218  >>> g.normal(shape=(2, 3))
219  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
220  array([[ 0.9356609 ,  1.0854305 , -0.93788373],
221         [-0.5061547 ,  1.3169702 ,  0.7137579 ]], dtype=float32)>
222
223  Creating a generator from a non-deterministic state:
224
225  >>> g = tf.random.Generator.from_non_deterministic_state()
226  >>> g.normal(shape=(2, 3))
227  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
228
  All the constructors allow explicitly choosing a Random-Number-Generation
  (RNG) algorithm. Supported algorithms are `"philox"` and `"threefry"`. For
  example:

  >>> g = tf.random.Generator.from_seed(123, alg="philox")
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
  array([[ 0.8673864 , -0.29899067, -0.9310337 ],
         [-1.5828488 ,  1.2481191 , -0.6770643 ]], dtype=float32)>

  CPU, GPU and TPU with the same algorithm and seed will generate the same
  integer random numbers. Floating-point results (such as the output of
  `normal`) may have small numerical discrepancies between different devices.

  This class uses a `tf.Variable` to manage its internal state. Every time
  random numbers are generated, the state of the generator will change. For
  example:

  >>> g = tf.random.Generator.from_seed(1234)
  >>> g.state
  <tf.Variable ... numpy=array([1234,    0,    0])>
  >>> g.normal(shape=(2, 3))
  <...>
  >>> g.state
  <tf.Variable ... numpy=array([2770,    0,    0])>

  The shape of the state is algorithm-specific.

  There is also a global generator:

  >>> g = tf.random.get_global_generator()
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>

  When creating a generator inside a `tf.distribute.Strategy` scope, each
  replica will get a different stream of random numbers.

  For example, in this code:

  ```
  strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
  with strat.scope():
    g = tf.random.Generator.from_seed(1)
    def f():
      return g.normal([])
    results = strat.run(f).values
  ```

  `results[0]` and `results[1]` will have different values.

  If the generator is seeded (e.g. created via `Generator.from_seed`), the
  random numbers will be determined by the seed, even though different replicas
  get different numbers.  One can think of a random number generated on a
  replica as a hash of the replica ID and a "master" random number that may be
  common to all replicas. Hence, the whole system is still deterministic.

  (Note that the random numbers on different replicas are not correlated, even
  though they are deterministically determined by the same seed. They are
  uncorrelated in the sense that no matter what statistics one calculates on
  them, there won't be any discernible correlation.)

  Generators can be freely saved and restored using `tf.train.Checkpoint`. The
  checkpoint can be restored in a distribution strategy with a different number
  of replicas than the original strategy. If a replica ID is present in both the
  original and the new distribution strategy, its state will be properly
  restored (i.e. the random-number stream from the restored point will be the
  same as that from the saving point) unless the replicas have already diverged
  in their RNG call traces before saving (e.g. one replica has made one RNG call
  while another has made two RNG calls). There is no such guarantee if the
  generator is saved in a strategy scope and restored outside of any strategy
  scope, or vice versa.

  When a generator is created within the scope of
  `tf.distribute.experimental.ParameterServerStrategy`, the workers
  will share the generator's state (placed on one of the parameter
  servers). In this way the workers will still get different
  random-number streams, as stated above. (This is similar to replicas
  in a `tf.distribute.MirroredStrategy` sequentially accessing a
  generator created outside the strategy.) Each RNG call on a worker
  will incur a round-trip to a parameter server, which may have
  performance impacts. When creating a
  `tf.distribute.experimental.ParameterServerStrategy`, please make
  sure that the `variable_partitioner` argument won't shard small
  variables of shape `[2]` or `[3]` (because generator states must not
  be sharded). Ways to avoid sharding small variables include setting
  `variable_partitioner` to `None` or to
  `tf.distribute.experimental.partitioners.MinSizePartitioner` with a
  large enough `min_shard_bytes` (see
  `tf.distribute.experimental.ParameterServerStrategy`'s documentation
  for more details).
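
  For example, a sketch of a safe configuration (here `cluster_resolver` and
  `NUM_PS` stand for whatever your cluster setup provides):

  ```
  strat = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver,
      variable_partitioner=(
          tf.distribute.experimental.partitioners.MinSizePartitioner(
              min_shard_bytes=(256 << 10), max_shards=NUM_PS)))
  with strat.scope():
    g = tf.random.Generator.from_seed(1)  # state variable won't be sharded
  ```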
319  """
320
321  @classmethod
322  def from_state(cls, state, alg):
323    """Creates a generator from a state.
324
325    See `__init__` for description of `state` and `alg`.
326
327    Args:
328      state: the new state.
329      alg: the RNG algorithm.
330
331    Returns:
332      The new generator.
333    """
334    return cls(alg=alg, state=state)
335
336  @classmethod
337  def from_seed(cls, seed, alg=None):
338    """Creates a generator from a seed.
339
    A seed is a 1024-bit unsigned integer represented either as a Python
    integer or a vector of integers. Seeds shorter than 1024 bits will be
    padded. The padding, the internal structure of a seed and the way a seed
    is converted to a state are all opaque (unspecified). The only specified
    semantics of seeds is that two different seeds are likely (but not
    guaranteed) to produce two independent generators.

    Args:
      seed: the seed for the RNG.
      alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
        `__init__` for its possible values.

    Returns:
      The new generator.
    """
    if alg is None:
      # TODO(b/170668986): more sophisticated algorithm selection
      alg = DEFAULT_ALGORITHM
    alg = stateless_random_ops.convert_alg_to_int(alg)
    state = create_rng_state(seed, alg)
    return cls(state=state, alg=alg)

  @classmethod
  def from_non_deterministic_state(cls, alg=None):
    """Creates a generator by non-deterministically initializing its state.

    The source of the non-determinism will be platform- and time-dependent.

    Args:
      alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
        `__init__` for its possible values.

    Returns:
      The new generator.
    """
    if config.is_op_determinism_enabled():
      raise RuntimeError('"from_non_deterministic_state" cannot be called when '  # pylint: disable=g-doc-exception
                         "determinism is enabled.")
    if alg is None:
      # TODO(b/170668986): more sophisticated algorithm selection
      alg = DEFAULT_ALGORITHM
    alg = stateless_random_ops.convert_alg_to_int(alg)
    state = non_deterministic_ints(shape=[_get_state_size(alg)],
                                   dtype=SEED_TYPE)
    return cls(state=state, alg=alg)

  @classmethod
  def from_key_counter(cls, key, counter, alg):
    """Creates a generator from a key and a counter.

    This constructor only applies if the algorithm is a counter-based algorithm.
    See method `key` for the meaning of "key" and "counter".
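
    A minimal example (the key and counter values here are arbitrary):

    >>> g = tf.random.Generator.from_key_counter(
    ...     key=1234, counter=[2, 3], alg="philox")
    >>> g.state
    <tf.Variable ... numpy=array([   2,    3, 1234])>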

    Args:
      key: the key for the RNG, a scalar of type STATE_TYPE.
      counter: a vector of dtype STATE_TYPE representing the initial counter for
        the RNG, whose length is algorithm-specific.
      alg: the RNG algorithm. See `__init__` for its possible values.

    Returns:
      The new generator.
    """
    counter = _convert_to_state_tensor(counter)
    key = _convert_to_state_tensor(key)
    alg = stateless_random_ops.convert_alg_to_int(alg)
    counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
    key.shape.assert_is_compatible_with([])
    key = array_ops.reshape(key, [1])
    state = array_ops.concat([counter, key], 0)
    return cls(state=state, alg=alg)

  def __init__(self, copy_from=None, state=None, alg=None):
    """Creates a generator.

    The new generator is initialized in one of the following ways, in
    decreasing order of precedence:
    (1) If `copy_from` is not None, the new generator is initialized by copying
        information from another generator.
    (2) If `state` and `alg` are not None (they must be set together), the new
        generator is initialized by a state.

    Args:
      copy_from: a generator to be copied from.
      state: a vector of dtype STATE_TYPE representing the initial state of the
        RNG, whose length and semantics are algorithm-specific. If it's a
        variable, the generator will reuse it instead of creating a new
        variable.
      alg: the RNG algorithm. Possible values are
        `tf.random.Algorithm.PHILOX` for the Philox algorithm and
        `tf.random.Algorithm.THREEFRY` for the ThreeFry algorithm
        (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
        [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
        The string names `"philox"` and `"threefry"` can also be used.
        Note `PHILOX` guarantees the same numbers are produced (given
        the same random state) across all architectures (CPU, GPU, XLA etc).
    """
    # TODO(b/175072242): Remove distribution-strategy dependencies in this file.
    if ds_context.has_strategy():
      self._distribution_strategy = ds_context.get_strategy()
    else:
      self._distribution_strategy = None
    if copy_from is not None:
      # All other arguments should be None
      assert (alg or state) is None
      self._state_var = self._create_variable(copy_from.state, dtype=STATE_TYPE,
                                              trainable=False)
      self._alg = copy_from.algorithm
    else:
      assert alg is not None and state is not None
      alg = stateless_random_ops.convert_alg_to_int(alg)
      if isinstance(state, variables.Variable):
        _check_state_shape(state.shape, alg)
        self._state_var = state
      else:
        state = _convert_to_state_tensor(state)
        _check_state_shape(state.shape, alg)
        self._state_var = self._create_variable(state, dtype=STATE_TYPE,
                                                trainable=False)
      self._alg = alg

  def _create_variable(self, *args, **kwargs):
    """Creates a variable.

    Args:
      *args: positional arguments passed along to `variables.Variable`.
      **kwargs: keyword arguments passed along to `variables.Variable`.

    Returns:
      The created variable.
    """
    with ops.name_scope("random_generator"):
      # Make sure we don't change this name since Keras was using this name
      # to filter out the state variable.
      kwargs["name"] = "StateVar"
      v = variables.Variable(*args, **kwargs)
    if isinstance(v, sharded_variable.ShardedVariable):
      # RNG state is an atomic entity representing a 128-bit or
      # 192-bit value, so it mustn't be sharded.
      raise ValueError(
          "tf.random.Generator state is sharded, which is not allowed. When "
          "creating a tf.distribute.experimental.ParameterServerStrategy, "
          "please make sure that the `variable_partitioner` "
          "argument won't shard a "
          "small variable of shape [2] or [3]. Ways to avoid sharding small "
          "variables include setting `variable_partitioner` to None or to "
          "tf.distribute.experimental.partitioners.MinSizePartitioner with a "
          "large enough `min_shard_bytes`.")
    return v

  def reset(self, state):
491    """Resets the generator by a new state.

    See `__init__` for the meaning of "state".

    Args:
      state: the new state.
    """
    state = _convert_to_state_tensor(state)
    state.shape.assert_is_compatible_with([_get_state_size(self.algorithm)])
    self._state_var.assign(state)

  def reset_from_seed(self, seed):
503    """Resets the generator by a new seed.

    See `from_seed` for the meaning of "seed".

    Args:
      seed: the new seed.
    """
    state = create_rng_state(seed, self.algorithm)
    self._state_var.assign(state)

  def reset_from_key_counter(self, key, counter):
514    """Resets the generator by a new key-counter pair.

    See `from_key_counter` for the meaning of "key" and "counter".

    Args:
      key: the new key.
      counter: the new counter.
    """
    counter = _convert_to_state_tensor(counter)
    key = _convert_to_state_tensor(key)
    counter.shape.assert_is_compatible_with(
        [_get_state_size(self.algorithm) - 1])
    key.shape.assert_is_compatible_with([])
    key = array_ops.reshape(key, [1])
    state = array_ops.concat([counter, key], 0)
    self._state_var.assign(state)

  @property
  def state(self):
    """The internal state of the RNG."""
    return self._state_var

  @property
  def algorithm(self):
    """The RNG algorithm id (a Python integer or scalar integer Tensor)."""
    return self._alg

  def _standard_normal(self, shape, dtype):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_random_normal_v2(
        shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)

  @property
  def key(self):
    """The 'key' part of the state of a counter-based RNG.

    For a counter-based RNG algorithm such as Philox or ThreeFry (as
    described in paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
    [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]),
    the RNG state consists of two parts: counter and key. The output is
    generated via the formula: output=hash(key, counter), i.e. a hashing of
    the counter parametrized by the key. Two RNGs with two different keys can
    be thought of as generating two independent random-number streams (a stream
    is formed by increasing the counter).
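
    For illustration (the key is the last element of the generator's state):

    >>> g = tf.random.Generator.from_key_counter(
    ...     key=56, counter=[0, 0], alg="philox")
    >>> g.key
    <tf.Tensor: shape=(), dtype=int64, numpy=56>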

    Returns:
      A scalar which is the 'key' part of the state, if the RNG algorithm is
        counter-based; otherwise it raises a ValueError.
    """
    alg = self.algorithm
    if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY:
      return self._state_var[-1]
    else:
      raise ValueError(
          f"This generator uses an unsupported algorithm {alg}. Supported "
          f"values are {RNG_ALG_PHILOX} for the Philox algorithm and "
          f"{RNG_ALG_THREEFRY} for the ThreeFry algorithm.")

  def _skip_single_var(self, var, delta):
    resource_variable_ops.variable_accessed(var)
    # TODO(wangpeng): Cache the cast algorithm instead of casting every time.
    return gen_stateful_random_ops.rng_read_and_skip(
        var.handle,
        alg=math_ops.cast(self.algorithm, dtypes.int32),
        delta=math_ops.cast(delta, dtypes.uint64))

  def skip(self, delta):
581    """Advance the counter of a counter-based RNG.

    Args:
      delta: the amount of advancement. The state of the RNG after
        `skip(n)` will be the same as that after `normal([n])`
        (or any other distribution). The actual increment added to the
        counter is an unspecified implementation detail.

    Returns:
      A `Tensor` of type `int64`.
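
    For example, since the state after `skip(n)` equals the state after
    drawing `n` samples (see `delta` above), this sketch leaves `g1` and `g2`
    holding equal states:

    ```python
    g1 = tf.random.Generator.from_seed(1)
    g2 = tf.random.Generator.from_seed(1)
    g1.skip(5)
    g2.normal([5])
    # g1.state and g2.state are now equal.
    ```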
591    """
592
593    def update_fn(v):
594      return self._skip_single_var(v, delta)
    # TODO(b/170515001): Always call strategy.extended.update once calling it
    #   from both replica context and cross-replica context is supported.
    if values_util.is_saving_non_distributed():
      # Assumes replica context with replica_id=0, since we only save the first
      # replica.
      return update_fn(self.state)
    if self._distribution_strategy is not None:
      with ds_context.enter_or_assert_strategy(self._distribution_strategy):
        if ds_context.in_cross_replica_context():
          # Code that operates on all replicas of a variable cannot be saved
          # without retracing.
          values_util.mark_as_unsaveable()
        if (ds_context.in_cross_replica_context() or
            "CentralStorage" in type(self._distribution_strategy).__name__):
          # In cross-replica context we need to use strategy.extended.update.
          # In CentralStorageStrategy we also need to use
          # strategy.extended.update (even for replica context),
          # because variable updates here must be within merge_call.
          return ds_context.get_strategy().extended.update(
              self.state, update_fn)
    return update_fn(self.state)

  def _preprocess_key(self, key):
    if self._distribution_strategy is None:
      return key
    with ds_context.enter_or_assert_strategy(self._distribution_strategy):
      replica_id = get_replica_id()
      if replica_id is not None:
        replica_id = array_ops.stack([replica_id, 0], axis=0)
        replica_id = math_ops.cast(replica_id, dtypes.uint64)
        # Conceptually: key = hash(key, replica_id)
        key = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
            shape=[1], key=key, counter=replica_id, dtype=dtypes.uint64,
            alg=self.algorithm)
      return key

  def _prepare_key_counter(self, shape):
    delta = math_ops.reduce_prod(shape)
    counter_key = self.skip(delta)
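    # Note (assumed from the op's read-and-skip semantics): `skip` returns the
    # state value from *before* the advancement, so `counter_key` holds the
    # [counter..., key] words that this batch of random numbers is generated
    # from.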
    counter_size = _get_counter_size(self.algorithm)
    counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
    key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
                            dtypes.uint64)
    key = self._preprocess_key(key)
    return key, counter

  # The following functions return a tensor and as a side effect update
  # self._state_var.
  def normal(self, shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
             name=None):
    """Outputs random values from a normal distribution.

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
        distribution.
      stddev: A 0-D Tensor or Python value of type `dtype`. The standard
        deviation of the normal distribution.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random normal values.
    """
    with ops.name_scope(name, "stateful_normal", [shape, mean, stddev]) as name:
      shape = _shape_tensor(shape)
      mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
      stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
      rnd = self._standard_normal(shape, dtype=dtype)
      return math_ops.add(rnd * stddev, mean, name=name)

  def _truncated_normal(self, shape, dtype):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
        shape=shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)

  def truncated_normal(self, shape,
                       mean=0.0,
                       stddev=1.0,
                       dtype=dtypes.float32,
                       name=None):
    """Outputs random values from a truncated normal distribution.

    The generated values follow a normal distribution with specified mean and
    standard deviation, except that values whose magnitude is more than
    2 standard deviations from the mean are dropped and re-picked.

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
        truncated normal distribution.
      stddev: A 0-D Tensor or Python value of type `dtype`. The standard
        deviation of the normal distribution, before truncation.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random truncated normal
        values.
    """
    with ops.name_scope(
        name, "truncated_normal", [shape, mean, stddev]) as name:
      shape_tensor = _shape_tensor(shape)
      mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
      stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
      rnd = self._truncated_normal(shape_tensor, dtype=dtype)
      mul = rnd * stddev_tensor
      return math_ops.add(mul, mean_tensor, name=name)

  def _uniform(self, shape, dtype):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_random_uniform_v2(
        shape=shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)

  def _uniform_full_int(self, shape, dtype, name=None):
    key, counter = self._prepare_key_counter(shape)
    return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
        shape=shape,
        key=key,
        counter=counter,
        dtype=dtype,
        alg=self.algorithm,
        name=name)

  def uniform(self, shape, minval=0, maxval=None,
              dtype=dtypes.float32, name=None):
    """Outputs random values from a uniform distribution.

    The generated values follow a uniform distribution in the range
    `[minval, maxval)`. The lower bound `minval` is included in the range, while
    the upper bound `maxval` is excluded. (For floating-point numbers,
    especially low-precision types like bfloat16, the result may sometimes
    include `maxval` because of rounding.)

    For floats, the default range is `[0, 1)`.  For ints, at least `maxval` must
    be specified explicitly.

    In the integer case, the random integers are slightly biased unless
    `maxval - minval` is an exact power of two.  The bias is small for values of
    `maxval - minval` significantly smaller than the range of the output (either
    `2**32` or `2**64`).

    For full-range random integers, pass `minval=None` and `maxval=None` with an
    integer `dtype` (for integer dtypes, `minval` and `maxval` must be both
    `None` or both not `None`).
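
    For example (the sampled values are random, so `numpy=...` elides them):

    >>> g = tf.random.Generator.from_seed(1234)
    >>> g.uniform(shape=[2], minval=None, maxval=None, dtype=tf.int32)
    <tf.Tensor: shape=(2,), dtype=int32, numpy=...>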

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      minval: A Tensor or Python value of type `dtype`, broadcastable with
        `shape` (for integer types, broadcasting is not supported, so it needs
        to be a scalar). The lower bound (included) on the range of random
        values to generate. Pass `None` for full-range integers. Defaults to 0.
      maxval: A Tensor or Python value of type `dtype`, broadcastable with
        `shape` (for integer types, broadcasting is not supported, so it needs
        to be a scalar). The upper bound (excluded) on the range of random
        values to generate. Pass `None` for full-range integers. Defaults to 1
        if `dtype` is floating point.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random uniform values.

    Raises:
      ValueError: If `dtype` is integral and `maxval` is not specified.
    """
    dtype = dtypes.as_dtype(dtype)
    if dtype.is_integer:
      if (minval is None) != (maxval is None):
        raise ValueError("For integer dtype {}, minval and maxval must be both "
                         "`None` or both non-`None`; got minval={} and "
                         "maxval={}".format(dtype, minval, maxval))
    elif maxval is None:
      maxval = 1
    with ops.name_scope(name, "stateful_uniform",
                        [shape, minval, maxval]) as name:
      shape = _shape_tensor(shape)
      if dtype.is_integer and minval is None:
        return self._uniform_full_int(shape=shape, dtype=dtype, name=name)
      minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
      maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
      if dtype.is_integer:
        key, counter = self._prepare_key_counter(shape)
        return gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
            shape=shape,
            key=key,
            counter=counter,
            minval=minval,
            maxval=maxval,
            alg=self.algorithm,
            name=name)
      else:
        rnd = self._uniform(shape=shape, dtype=dtype)
        return math_ops.add(rnd * (maxval - minval), minval, name=name)

  def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None):
    """Uniform distribution on an integer type's entire range.

    This method is the same as setting `minval` and `maxval` to `None` in the
    `uniform` method.

    Args:
      shape: the shape of the output.
      dtype: (optional) the integer type, defaulting to uint64.
      name: (optional) the name of the node.

    Returns:
      A tensor of random numbers of the required shape.
    """
    dtype = dtypes.as_dtype(dtype)
    with ops.name_scope(name, "stateful_uniform_full_int",
                        [shape]) as name:
      shape = _shape_tensor(shape)
      return self._uniform_full_int(shape=shape, dtype=dtype, name=name)

  def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None):
    """Outputs random values from a binomial distribution.

    The generated values follow a binomial distribution with specified count and
    probability of success parameters.

    Example:

    ```python
    counts = [10., 20.]
    # Probability of success.
    probs = [0.8]

    rng = tf.random.Generator.from_seed(seed=234)
    binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs)


    counts = ... # Shape [3, 1, 2]
    probs = ...  # Shape [1, 4, 2]
    shape = [3, 4, 3, 4, 2]
    rng = tf.random.Generator.from_seed(seed=1717)
    # Sample shape will be [3, 4, 3, 4, 2]
    binomial_samples = rng.binomial(shape=shape, counts=counts, probs=probs)
    ```


    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      counts: Tensor. The counts of the binomial distribution. Must be
        broadcastable with `probs`, and broadcastable with the rightmost
        dimensions of `shape`.
      probs: Tensor. The probability of success for the
        binomial distribution. Must be broadcastable with `counts` and
        broadcastable with the rightmost dimensions of `shape`.
      dtype: The type of the output. Default: tf.int32
      name: A name for the operation (optional).

    Returns:
      samples: A Tensor of the specified shape filled with random binomial
        values.  For each i, each samples[i, ...] is an independent draw from
        the binomial distribution on counts[i] trials with probability of
        success probs[i].
    """
    dtype = dtypes.as_dtype(dtype)
    with ops.name_scope(name, "binomial", [shape, counts, probs]) as name:
      counts = ops.convert_to_tensor(counts, name="counts")
      probs = ops.convert_to_tensor(probs, name="probs")
      shape_tensor = _shape_tensor(shape)
      return gen_stateful_random_ops.stateful_random_binomial(
          self.state.handle,
          self.algorithm,
          shape=shape_tensor,
          counts=counts,
          probs=probs,
          dtype=dtype,
          name=name)

  # TODO(wangpeng): implement other distributions

  def _make_int64_keys(self, shape=()):
    # New independent keys are generated via
    # `new_key[i] = hash(old_key, counter+i)`, which is exactly what
    # `uniform_full_int(dtype=int64)` does for PhiloxRandom_64_128_128 and
    # ThreeFry_64_64_64.
    return self.uniform_full_int(shape=shape, dtype=dtypes.int64)

  def make_seeds(self, count=1):
    """Generates seeds for stateless random ops.

    For example:

    ```python
    seeds = get_global_generator().make_seeds(count=10)
    for i in range(10):
      seed = seeds[:, i]
      numbers = stateless_random_normal(shape=[2, 3], seed=seed)
      ...
    ```

    Args:
      count: the number of seed pairs (note that stateless random ops need a
        pair of seeds to invoke).

    Returns:
      A tensor of shape [2, count] and dtype int64.
    """
    alg = self.algorithm
    if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY:
      keys = self._make_int64_keys(shape=[count])
      # The two seeds for stateless random ops don't have individual semantics
      # and are scrambled together, so setting one to zero is fine.
      zeros = array_ops.zeros_like(keys)
      return array_ops.stack([keys, zeros])
    else:
      raise ValueError(
          f"This generator uses an unsupported algorithm {alg}. Supported "
          f"values are {RNG_ALG_PHILOX} for the Philox algorithm and "
          f"{RNG_ALG_THREEFRY} for the ThreeFry algorithm.")

  def split(self, count=1):
    """Returns a list of independent `Generator` objects.

    Two generators are independent of each other in the sense that the
    random-number streams they generate don't have statistically detectable
    correlations. The new generators are also independent of the old one.
    The old generator's state will be changed (like other random-number
    generating methods), so two calls of `split` will return different
    new generators.

    For example:

    ```python
    gens = get_global_generator().split(count=10)
    for gen in gens:
      numbers = gen.normal(shape=[2, 3])
      # ...
    gens2 = get_global_generator().split(count=10)
    # gens2 will be different from gens
    ```

    The new generators will be put on the current device (possibly different
    from the old generator's), for example:

    ```python
    with tf.device("/device:CPU:0"):
      gen = Generator.from_seed(1234)  # gen is on CPU
    with tf.device("/device:GPU:0"):
      gens = gen.split(count=10)  # gens are on GPU
    ```

    Args:
      count: the number of generators to return.

    Returns:
      A list (length `count`) of `Generator` objects independent of each other.
      The new generators have the same RNG algorithm as the old one.
    """
    def _key_to_state(alg, key):
      # Padding with zeros on the left. The zeros will be the counter.
      return [0] * (_get_state_size(alg) - 1) + [key]

    alg = self.algorithm
    if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY:
      keys = self._make_int64_keys(shape=[count])
      return [Generator(state=_key_to_state(alg, key), alg=alg)
              for key in array_ops.unstack(keys, num=count)]
    else:
      raise ValueError(
          f"This generator uses an unsupported algorithm {alg}. Supported "
          f"values are {RNG_ALG_PHILOX} for the Philox algorithm and "
          f"{RNG_ALG_THREEFRY} for the ThreeFry algorithm.")


# It's not safe to create TF ops before `init_google` is called, so this is
# initialized to None and gets a value the first time `get_global_generator`
# is called.
global_generator = None


@tf_export("random.get_global_generator",
           "random.experimental.get_global_generator")
def get_global_generator():
  """Retrieves the global generator.

  This function will create the global generator the first time it is called,
  and the generator will be placed on the default device at that time, so one
  needs to be careful when this function is first called. Using a generator
  placed on a less-than-ideal device will incur a performance regression.

  Returns:
    The global `tf.random.Generator` object.
  """
  global global_generator
  if global_generator is None:
    if config.is_op_determinism_enabled():
      raise RuntimeError('"get_global_generator" cannot be called if '  # pylint: disable=g-doc-exception
                         "determinism is enabled, unless "
                         '"set_global_generator" has already been called. '
                         'Please call "set_global_generator" first.')
    with ops.init_scope():
      global_generator = Generator.from_non_deterministic_state()
  return global_generator


@tf_export("random.set_global_generator",
           "random.experimental.set_global_generator")
def set_global_generator(generator):
  """Replaces the global generator with another `Generator` object.

  This function replaces the global generator with the provided `generator`
  object.
  A random number generator uses a `tf.Variable` object to store its state,
  so the user should be aware of the following caveats about how
  `set_global_generator` interacts with `tf.function`:

  - tf.function puts restrictions on Variable creation, so one cannot freely
    create a new random generator instance inside `tf.function`.
    To call `set_global_generator` inside `tf.function`, the generator instance
    must have already been created eagerly.
  - tf.function captures the Variable during trace-compilation, so a compiled
    tf.function will not be affected by `set_global_generator`, as demonstrated
    by random_test.py/RandomTest.testResetGlobalGeneratorBadWithDefun and
    sketched below.
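
  A minimal sketch of the second caveat (not a pattern to imitate):

  >>> @tf.function
  ... def f():
  ...   return tf.random.get_global_generator().normal([])
  >>> tf.random.set_global_generator(tf.random.Generator.from_seed(1))
  >>> _ = f()  # Traces `f`, capturing the current generator's state variable.
  >>> tf.random.set_global_generator(tf.random.Generator.from_seed(2))
  >>> _ = f()  # Still uses the first generator's captured state variable.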

  For most use cases, avoid calling `set_global_generator` after program
  initialization; prefer to reset the state of the existing global generator
  instead. For example:

  >>> rng = tf.random.get_global_generator()
  >>> rng.reset_from_seed(30)


  Args:
    generator: the new `Generator` object.
  """
  global global_generator
  global_generator = generator