xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/init_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operations often used for initializing tensors.
16
17All variable initializers returned by functions in this file should have the
18following signature:
19
20def _initializer(shape, dtype=dtypes.float32, partition_info=None):
21  Args:
22    shape: List of `int` representing the shape of the output `Tensor`. Some
23      initializers may also be able to accept a `Tensor`.
24    dtype: (Optional) Type of the output `Tensor`.
25    partition_info: (Optional) variable_scope._PartitionInfo object holding
26      additional information about how the variable is partitioned. May be
27      `None` if the variable is not partitioned.
28
29  Returns:
30    A `Tensor` of type `dtype` and `shape`.
31"""
32import math
33
34import numpy as np
35
36from tensorflow.python.framework import constant_op
37from tensorflow.python.framework import dtypes
38from tensorflow.python.framework import tensor_shape
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import gen_linalg_ops
41from tensorflow.python.ops import linalg_ops_impl
42from tensorflow.python.ops import math_ops
43from tensorflow.python.ops import random_ops
44from tensorflow.python.util import deprecation
45from tensorflow.python.util.deprecation import deprecated
46from tensorflow.python.util.deprecation import deprecated_arg_values
47from tensorflow.python.util.deprecation import deprecated_args
48from tensorflow.python.util.tf_export import tf_export
49
50
51class Initializer:
52  """Initializer base class: all initializers inherit from this class."""
53
54  def __call__(self, shape, dtype=None, partition_info=None):
55    """Returns a tensor object initialized as specified by the initializer.
56
57    Args:
58      shape: Shape of the tensor.
59      dtype: Optional dtype of the tensor. If not provided use the initializer
60        dtype.
61      partition_info: Optional information about the possible partitioning of a
62        tensor.
63    """
64    raise NotImplementedError
65
66  def get_config(self):
67    """Returns the configuration of the initializer as a JSON-serializable dict.
68
69    Returns:
70      A JSON-serializable Python dict.
71    """
72    return {}
73
74  @classmethod
75  def from_config(cls, config):
76    """Instantiates an initializer from a configuration dictionary.
77
78    Example:
79
80    ```python
81    initializer = RandomUniform(-1, 1)
82    config = initializer.get_config()
83    initializer = RandomUniform.from_config(config)
84    ```
85
86    Args:
87      config: A Python dictionary. It will typically be the output of
88        `get_config`.
89
90    Returns:
91      An Initializer instance.
92    """
93    return cls(**config)
94
95
96@tf_export(v1=["initializers.zeros", "zeros_initializer"])
97@deprecation.deprecated_endpoints("initializers.zeros")
98class Zeros(Initializer):
99  """Initializer that generates tensors initialized to 0.
100
101  @compatibility(TF2)
102  `tf.compat.v1.zeros_initializer` is compatible with eager execution
103  and `tf.function`.
104
105  To migrate to TF2, please use `tf.zeros_initializer` instead. The `dtype`
106  argument in `tf.compat.v1.zeros_initializer.__init__()` does not exist in
107  `tf.zeros_initializer.__init__()`. However, you can specify the `dtype` in
108  `__call__()` in both cases.
109
110  #### Structural Mapping to TF2
111
112  Before:
113
114  ```python
115  initializer = tf.compat.v1.zeros_initializer(dtype=tf.float32)
116  variable = tf.Variable(initializer(shape=[3, 3]))
117  ```
118
119  After:
120
121  ```python
122  initializer = tf.zeros_initializer()
123  variable = tf.Variable(initializer(shape=[3, 3], dtype=tf.float32))
124  ```
125
126  #### How to Map Arguments
127
128  | TF1 Arg Name         | TF2 Arg Name     | Note                       |
129  | :------------------- | :--------------- | :------------------------- |
130  | `dtype`              | `dtype`          | In `__call__()` method     |
131  | `partition_info`     | - |  (`__call__` arg in TF1) Not supported    |
132
133
134  #### Before & After Usage Example
135
136  Before:
137
138  >>> initializer = tf.compat.v1.zeros_initializer(dtype=tf.float32)
139  >>> tf.Variable(initializer(shape=[3])).numpy()
140  array([0., 0., 0.], dtype=float32)
141  >>> tf.Variable(initializer(shape=[3, 3])).numpy()
142  array([[0., 0., 0.],
143         [0., 0., 0.],
144         [0., 0., 0.]], dtype=float32)
145  >>> initializer = tf.compat.v1.zeros_initializer()
146  >>> tf.Variable(initializer(shape=[3], dtype=tf.float32)).numpy()
147  array([0., 0., 0.], dtype=float32)
148  >>> tf.Variable(initializer(shape=[3, 3], dtype=tf.float32)).numpy()
149  array([[0., 0., 0.],
150         [0., 0., 0.],
151         [0., 0., 0.]], dtype=float32)
152
153  After:
154
155  >>> initializer = tf.zeros_initializer()
156  >>> tf.Variable(initializer(shape=[3], dtype=tf.float32)).numpy()
157  array([0., 0., 0.], dtype=float32)
158  >>> tf.Variable(initializer(shape=[3, 3], dtype=tf.float32)).numpy()
159  array([[0., 0., 0.],
160         [0., 0., 0.],
161         [0., 0., 0.]], dtype=float32)
162
163  @end_compatibility
164  """
165
166  @deprecated_args(None,
167                   "Call initializer instance with the dtype argument instead "
168                   "of passing it to the constructor", "dtype")
169  def __init__(self, dtype=dtypes.float32):
170    self.dtype = dtypes.as_dtype(dtype)
171
172  def __call__(self, shape, dtype=None, partition_info=None):
173    if dtype is None:
174      dtype = self.dtype
175    return array_ops.zeros(shape, dtype)
176
177  def get_config(self):
178    return {"dtype": self.dtype.name}
179
180
181@tf_export(v1=["initializers.ones", "ones_initializer"])
182@deprecation.deprecated_endpoints("initializers.ones", "ones_initializer")
183class Ones(Initializer):
184  """Initializer that generates tensors initialized to 1.
185
186  @compatibility(TF2)
187  This API is compatible with TF2 behavior and `tf.function`, and can be
188  migrated immediately with `tf.keras.initializers.ones`.
189
190  Before:
191  >>> initializer = tf.compat.v1.keras.initializers.ones()
192  >>> initializer((1, 1))
193  <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.]], dtype=float32)>
194
195  After:
196  >>> initializer = tf.keras.initializers.ones()
197  >>> initializer((1, 1))
198  <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.]], dtype=float32)>
199
200  @end_compatibility
201  """
202
203  @deprecated_args(None,
204                   "Call initializer instance with the dtype argument instead "
205                   "of passing it to the constructor", "dtype")
206  def __init__(self, dtype=dtypes.float32):
207    self.dtype = dtypes.as_dtype(dtype)
208
209  def __call__(self, shape, dtype=None, partition_info=None):
210    if dtype is None:
211      dtype = self.dtype
212    return array_ops.ones(shape, dtype)
213
214  def get_config(self):
215    return {"dtype": self.dtype.name}
216
217
218@tf_export(v1=["initializers.constant", "constant_initializer"])
219@deprecation.deprecated_endpoints("constant_initializer")
220class Constant(Initializer):
221  """Initializer that generates tensors with constant values.
222
223  The resulting tensor is populated with values of type `dtype`, as
224  specified by arguments `value` following the desired `shape` of the
225  new tensor (see examples below).
226
227  The argument `value` can be a constant value, or a list of values of type
228  `dtype`. If `value` is a list, then the length of the list must be less
229  than or equal to the number of elements implied by the desired shape of the
230  tensor. In the case where the total number of elements in `value` is less
231  than the number of elements required by the tensor shape, the last element
232  in `value` will be used to fill the remaining entries. If the total number of
233  elements in `value` is greater than the number of elements required by the
234  tensor shape, the initializer will raise a `ValueError`.
235
236  Args:
237    value: A Python scalar, list or tuple of values, or a N-dimensional numpy
238      array. All elements of the initialized variable will be set to the
239      corresponding value in the `value` argument.
240    dtype: Default data type, used if no `dtype` argument is provided when
241      calling the initializer.
242    verify_shape: Boolean that enables verification of the shape of `value`. If
243      `True`, the initializer will throw an error if the shape of `value` is not
244      compatible with the shape of the initialized tensor.
245
246  Raises:
247    TypeError: If the input `value` is not one of the expected types.
248
249  Examples:
250    The following example can be rewritten using a numpy.ndarray instead
251    of the `value` list, even reshaped, as shown in the two commented lines
252    below the `value` list initialization.
253
254  >>> value = [0, 1, 2, 3, 4, 5, 6, 7]
255  >>> init = tf.compat.v1.constant_initializer(value)
256  >>> # fitting shape
257  >>> with tf.compat.v1.Session():
258  ...   x = tf.compat.v1.get_variable('x', shape=[2, 4], initializer=init)
259  ...   x.initializer.run()
260  ...   print(x.eval())
261  [[0. 1. 2. 3.]
262   [4. 5. 6. 7.]]
263  >>> # Larger shape
264  >>> with tf.compat.v1.Session():
265  ...   y = tf.compat.v1.get_variable('y', shape=[3, 4], initializer=init)
266  ...   y.initializer.run()
267  ...   print(y.eval())
268  [[0.  1.  2.  3.]
269   [4.  5.  6.  7.]
270   [7.  7.  7.  7.]]
271  >>> # Smaller shape
272  >>> with tf.compat.v1.Session():
273  ...   z = tf.compat.v1.get_variable('z', shape=[2, 3], initializer=init)
274  Traceback (most recent call last):
275  ...
276  ValueError: Too many elements provided. Needed at most 6, but received 8
277  >>> # Shape verification
278  >>> init_verify = tf.compat.v1.constant_initializer(value, verify_shape=True)
279  >>> with tf.compat.v1.Session():
280  ...  u = tf.compat.v1.get_variable('u', shape=[3, 4],
281  ...                                initializer=init_verify)
282  Traceback (most recent call last):
283  ...
284  TypeError: Expected Tensor's shape: (3, 4), got (8,).
285
286  @compatibility(TF2)
287  Although it is a legacy API endpoint, `tf.compat.v1.constant_initializer`
288  is compatible with eager execution and `tf.function`.
289
290  To migrate to a non-legacy TF2 API, please use `tf.constant_initializer`
291  instead. The `dtype`
292  argument in `tf.compat.v1.constant_initializer.__init__()` does not exist in
293  `tf.constant_initializer.__init__()`. However, you can specify the `dtype` in
294  `__call__()` in both cases.
295
296  In the `compat.v1` symbol, if `verify_shape` is set to `True`, an exception
297  is raised when initializing a variable with a different shape from
298  `value`. If set to `False`, `value` is reshaped to initialize the variable
299  if necessary. An exception would only be raised when the number of
300  elements are different.
301
302  The `verify_shape` argument is not supported in TF2. Using
303  `tf.constant_initializer` is equivalent to setting `verify_shape` to `False`.
304
305  #### Structural Mapping to TF2
306
307  Before:
308
309  ```python
310  value = [0, 1, 2, 3, 4, 5, 6, 7]
311  initializer = tf.compat.v1.constant_initializer(
312      value=value,
313      dtype=tf.float32,
314      verify_shape=False)
315  variable = tf.Variable(initializer(shape=[2, 4]))
316  ```
317
318  After:
319
320  ```python
321  value = [0, 1, 2, 3, 4, 5, 6, 7]
322  initializer = tf.constant_initializer(value=value)
323  tf.Variable(initializer(shape=[2, 4], dtype=tf.float32))
324  ```
325
326  #### How to Map Arguments
327
328  | TF1 Arg Name          | TF2 Arg Name     | Note                        |
329  | :-------------------- | :--------------- | :-------------------------- |
330  | `value`               | `value`          | In constructor              |
331  | `dtype`               | `dtype`          | In `__call__()` method      |
332  | `verify_shape`        | Not Supported    | Equivalent to set to `False`|
333  | `partition_info`      | - |  (`__call__` arg in TF1) Not supported     |
334
335
336  #### Before & After Usage Example
337
338  Before:
339
340  >>> value = [1., 2., 3., 4.]
341  >>> initializer = tf.compat.v1.constant_initializer(
342  ...     value=value, dtype=tf.float32, verify_shape=True)
343  >>> tf.Variable(initializer(shape=[2, 2])).numpy()
344  Traceback (most recent call last):
345  ...
346  TypeError: Expected Tensor's shape: (2, 2), got (4,).
347  >>> initializer = tf.compat.v1.constant_initializer(
348  ...     value=value, dtype=tf.float32, verify_shape=False)
349  >>> tf.Variable(initializer(shape=[2, 2])).numpy()
350  array([[1., 2.],
351         [3., 4.]], dtype=float32)
352
353  After:
354
355  >>> value = [1., 2., 3., 4.]
356  >>> initializer = tf.constant_initializer(value=value)
357  >>> tf.Variable(initializer(shape=[2, 2], dtype=tf.float32)).numpy()
358  array([[1., 2.],
359         [3., 4.]], dtype=float32)
360
361  @end_compatibility
362  """
363
364  @deprecated_args(None,
365                   "Call initializer instance with the dtype argument instead "
366                   "of passing it to the constructor", "dtype")
367  @deprecated_args(None, "Objects must now be the required shape or no shape "
368                   "can be specified", "verify_shape")
369  def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
370    if not (np.isscalar(value) or isinstance(value, (list, tuple, np.ndarray))):
371      raise TypeError(
372          f"Invalid type for initial value={value} of type: "
373          f"{type(value).__name__}. Expected Python scalar, list or tuple of "
374          "values, or numpy.ndarray.")
375
376    self.value = value
377    self.dtype = dtypes.as_dtype(dtype)
378    self._verify_shape = verify_shape
379
380  def __call__(self, shape, dtype=None, partition_info=None, verify_shape=None):
381    if dtype is None:
382      dtype = self.dtype
383    if verify_shape is None:
384      verify_shape = self._verify_shape
385    return constant_op.constant_v1(
386        self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)
387
388  def get_config(self):
389    # We don't include `verify_shape` for compatibility with Keras.
390    # `verify_shape` should be passed as an argument to `__call__` rather
391    # than as a constructor argument: conceptually it isn't a property
392    # of the initializer.
393    return {"value": self.value, "dtype": self.dtype.name}
394
395
396@tf_export(v1=["initializers.random_uniform", "random_uniform_initializer"])
397@deprecation.deprecated_endpoints("initializers.random_uniform")
398class RandomUniform(Initializer):
399  """Initializer that generates tensors with a uniform distribution.
400
401  Args:
402    minval: A python scalar or a scalar tensor. Lower bound of the range of
403      random values to generate.
404    maxval: A python scalar or a scalar tensor. Upper bound of the range of
405      random values to generate.  Defaults to 1 for float types.
406    seed: A Python integer. Used to create random seeds. See
407      `tf.compat.v1.set_random_seed` for behavior.
408    dtype: Default data type, used if no `dtype` argument is provided when
409      calling the initializer.
410
411  @compatibility(TF2)
412  Although it is a legacy compat.v1 API, this symbol is compatible with eager
413  execution and `tf.function`.
414
415  To switch to TF2, switch to using either
416  `tf.initializers.RandomUniform` or `tf.keras.initializers.RandomUniform`
417  (neither from `compat.v1`) and
418  pass the dtype when calling the initializer. Keep in mind that
419  the default minval, maxval and the behavior of fixed seeds have changed.
420
421  #### Structural Mapping to TF2
422
423  Before:
424
425  ```python
426  initializer = tf.compat.v1.random_uniform_initializer(
427    minval=minval,
428    maxval=maxval,
429    seed=seed,
430    dtype=dtype)
431
432  weight_one = tf.Variable(initializer(shape_one))
433  weight_two = tf.Variable(initializer(shape_two))
434  ```
435
436  After:
437
438  ```python
439  initializer = tf.initializers.RandomUniform(
440    minval=minval,
441    maxval=maxval,
442    seed=seed)
443
444  weight_one = tf.Variable(initializer(shape_one, dtype=dtype))
445  weight_two = tf.Variable(initializer(shape_two, dtype=dtype))
446  ```
447
448  #### How to Map Arguments
449
450  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
451  | :-------------------- | :-------------- | :------------------------- |
452  | `minval`               | `minval`    | Default changes from 0 to -0.05 |
453  | `maxval`         | `maxval`        | Default changes from 1.0 to 0.05 |
454  | `seed`             | `seed` |  |
455  | `dtype` | `dtype`   | The TF2 native api only takes it  |
456  :                     :      : as a `__call__` arg, not a constructor arg. :
457  | `partition_info`     | - |  (`__call__` arg in TF1) Not supported       |
458
459  @end_compatibility
460  """
461
462  @deprecated_args(None,
463                   "Call initializer instance with the dtype argument instead "
464                   "of passing it to the constructor", "dtype")
465  def __init__(self, minval=.0, maxval=None, seed=None, dtype=dtypes.float32):
466    self.minval = minval
467    self.maxval = maxval
468    self.seed = seed
469    self.dtype = dtypes.as_dtype(dtype)
470
471  def __call__(self, shape, dtype=None, partition_info=None):
472    if dtype is None:
473      dtype = self.dtype
474    return random_ops.random_uniform(
475        shape, self.minval, self.maxval, dtype, seed=self.seed)
476
477  def get_config(self):
478    return {
479        "minval": self.minval,
480        "maxval": self.maxval,
481        "seed": self.seed,
482        "dtype": self.dtype.name
483    }
484
485
486@tf_export(v1=["initializers.random_normal", "random_normal_initializer"])
487@deprecation.deprecated_endpoints("initializers.random_normal")
488class RandomNormal(Initializer):
489  """Initializer that generates tensors with a normal distribution.
490
491  Args:
492    mean: a python scalar or a scalar tensor. Mean of the random values to
493      generate.
494    stddev: a python scalar or a scalar tensor. Standard deviation of the random
495      values to generate.
496    seed: A Python integer. Used to create random seeds. See
497      `tf.compat.v1.set_random_seed` for behavior.
498    dtype: Default data type, used if no `dtype` argument is provided when
499      calling the initializer. Only floating point types are supported.
500
501  @compatibility(TF2)
502  Although it is a legacy `compat.v1` API, this symbol is compatible with eager
503  execution and `tf.function`.
504
505  To switch to TF2, switch to using either
506  `tf.initializers.RandomNormal` or `tf.keras.initializers.RandomNormal`
507  (neither from `compat.v1`) and
508  pass the dtype when calling the initializer. Keep in mind that
509  the default stddev and the behavior of fixed seeds have changed.
510
511  #### Structural Mapping to TF2
512
513  Before:
514
515  ```python
516  initializer = tf.compat.v1.random_normal_initializer(
517    mean=mean,
518    stddev=stddev,
519    seed=seed,
520    dtype=dtype)
521
522  weight_one = tf.Variable(initializer(shape_one))
523  weight_two = tf.Variable(initializer(shape_two))
524  ```
525
526  After:
527
528  ```python
529  initializer = tf.initializers.RandomNormal(
530    mean=mean,
531    seed=seed,
532    stddev=stddev)
533
534  weight_one = tf.Variable(initializer(shape_one, dtype=dtype))
535  weight_two = tf.Variable(initializer(shape_two, dtype=dtype))
536  ```
537
538  #### How to Map Arguments
539
540  | TF1 Arg Name       | TF2 Arg Name    | Note                       |
541  | :----------------- | :-------------- | :------------------------- |
542  | `mean`             | `mean`          | No change to defaults |
543  | `stddev`           | `stddev`        | Default changes from 1.0 to 0.05 |
544  | `seed`             | `seed`          |                                  |
545  | `dtype`            | `dtype`  | The TF2 native api only takes it as a |
546  :                    :          : `__call__` arg, not a constructor arg. :
547  | `partition_info`   | -     |  (`__call__` arg in TF1) Not supported.  |
548
549  @end_compatibility
550  """
551
552  @deprecated_args(None,
553                   "Call initializer instance with the dtype argument instead "
554                   "of passing it to the constructor", "dtype")
555  def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
556    self.mean = mean
557    self.stddev = stddev
558    self.seed = seed
559    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
560
561  def __call__(self, shape, dtype=None, partition_info=None):
562    if dtype is None:
563      dtype = self.dtype
564    return random_ops.random_normal(
565        shape, self.mean, self.stddev, dtype, seed=self.seed)
566
567  def get_config(self):
568    return {
569        "mean": self.mean,
570        "stddev": self.stddev,
571        "seed": self.seed,
572        "dtype": self.dtype.name
573    }
574
575
576@tf_export(v1=["initializers.truncated_normal", "truncated_normal_initializer"])
577@deprecation.deprecated_endpoints("initializers.truncated_normal",
578                                  "truncated_normal_initializer")
579class TruncatedNormal(Initializer):
580  """Initializer that generates a truncated normal distribution.
581
582  These values are similar to values from a `random_normal_initializer`
583  except that values more than two standard deviations from the mean
584  are discarded and re-drawn. This is the recommended initializer for
585  neural network weights and filters.
586
587  Args:
588    mean: a python scalar or a scalar tensor. Mean of the random values to
589      generate.
590    stddev: a python scalar or a scalar tensor. Standard deviation of the random
591      values to generate.
592    seed: A Python integer. Used to create random seeds. See
593      `tf.compat.v1.set_random_seed` for behavior.
594    dtype: Default data type, used if no `dtype` argument is provided when
595      calling the initializer. Only floating point types are supported.
596
597  @compatibility(TF2)
598  Although it is a legacy `compat.v1` API, this symbol is compatible with eager
599  execution and `tf.function`.
600
601  To switch to TF2, switch to using either
602  `tf.initializers.truncated_normal` or `tf.keras.initializers.TruncatedNormal`
603  (neither from `compat.v1`) and
604  pass the dtype when calling the initializer. Keep in mind that
605  the default stddev and the behavior of fixed seeds have changed.
606
607  #### Structural Mapping to TF2
608
609  Before:
610
611  ```python
612  initializer = tf.compat.v1.truncated_normal_initializer(
613    mean=mean,
614    stddev=stddev,
615    seed=seed,
616    dtype=dtype)
617
618  weight_one = tf.Variable(initializer(shape_one))
619  weight_two = tf.Variable(initializer(shape_two))
620  ```
621
622  After:
623
624  ```python
625  initializer = tf.initializers.truncated_normal(
626    mean=mean,
627    seed=seed,
628    stddev=stddev)
629
630  weight_one = tf.Variable(initializer(shape_one, dtype=dtype))
631  weight_two = tf.Variable(initializer(shape_two, dtype=dtype))
632  ```
633
634  #### How to Map Arguments
635
636  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
637  | :-------------------- | :-------------- | :------------------------- |
638  | `mean`               | `mean`        | No change to defaults |
639  | `stddev`         | `stddev`        | Default changes from 1.0 to 0.05 |
640  | `seed`             | `seed` | |
641  | `dtype` | `dtype`   | The TF2 native api only takes it  |
642  :                     :      : as a `__call__` arg, not a constructor arg. :
643  | `partition_info`     | - |  (`__call__` arg in TF1) Not supported       |
644
645  @end_compatibility
646  """
647
648  @deprecated_args(None,
649                   "Call initializer instance with the dtype argument instead "
650                   "of passing it to the constructor", "dtype")
651  def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
652    self.mean = mean
653    self.stddev = stddev
654    self.seed = seed
655    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
656
657  def __call__(self, shape, dtype=None, partition_info=None):
658    if dtype is None:
659      dtype = self.dtype
660    return random_ops.truncated_normal(
661        shape, self.mean, self.stddev, dtype, seed=self.seed)
662
663  def get_config(self):
664    return {
665        "mean": self.mean,
666        "stddev": self.stddev,
667        "seed": self.seed,
668        "dtype": self.dtype.name
669    }
670
671
672@tf_export(v1=[
673    "initializers.uniform_unit_scaling", "uniform_unit_scaling_initializer"
674])
675@deprecation.deprecated_endpoints("uniform_unit_scaling_initializer",
676                                  "initializers.uniform_unit_scaling")
677class UniformUnitScaling(Initializer):
678  """Initializer that generates tensors without scaling variance.
679
680  When initializing a deep network, it is in principle advantageous to keep
681  the scale of the input variance constant, so it does not explode or diminish
682  by reaching the final layer. If the input is `x` and the operation `x * W`,
683  and we want to initialize `W` uniformly at random, we need to pick `W` from
684
685      [-sqrt(3) / sqrt(dim), sqrt(3) / sqrt(dim)]
686
687  to keep the scale intact, where `dim = W.shape[0]` (the size of the input).
688  A similar calculation for convolutional networks gives an analogous result
689  with `dim` equal to the product of the first 3 dimensions.  When
690  nonlinearities are present, we need to multiply this by a constant `factor`.
691  See (Sussillo et al., 2014) for deeper motivation, experiments
692  and the calculation of constants. In section 2.3 there, the constants were
693  numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15.
694
695  Args:
696    factor: Float.  A multiplicative factor by which the values will be scaled.
697    seed: A Python integer. Used to create random seeds. See
698      `tf.compat.v1.set_random_seed` for behavior.
699    dtype: Default data type, used if no `dtype` argument is provided when
700      calling the initializer. Only floating point types are supported.
701  References:
702      [Sussillo et al., 2014](https://arxiv.org/abs/1412.6558)
703      ([pdf](http://arxiv.org/pdf/1412.6558.pdf))
704  """
705
706  @deprecated_args(None,
707                   "Call initializer instance with the dtype argument instead "
708                   "of passing it to the constructor", "dtype")
709  @deprecated(None,
710              "Use tf.initializers.variance_scaling instead with distribution="
711              "uniform to get equivalent behavior.")
712  def __init__(self, factor=1.0, seed=None, dtype=dtypes.float32):
713    self.factor = factor
714    self.seed = seed
715    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
716
717  def __call__(self, shape, dtype=None, partition_info=None):
718    if dtype is None:
719      dtype = self.dtype
720    scale_shape = shape
721    if partition_info is not None:
722      scale_shape = partition_info.full_shape
723
724    input_size = 1.0
725    # Estimating input size is not possible to do perfectly, but we try.
726    # The estimate, obtained by multiplying all dimensions but the last one,
727    # is the right thing for matrix multiply and convolutions (see above).
728    for dim in scale_shape[:-1]:
729      input_size *= float(dim)
730    # Avoid errors when initializing zero-size tensors.
731    input_size = max(input_size, 1.0)
732    max_val = math.sqrt(3 / input_size) * self.factor
733    return random_ops.random_uniform(
734        shape, -max_val, max_val, dtype, seed=self.seed)
735
736  def get_config(self):
737    return {"factor": self.factor, "seed": self.seed, "dtype": self.dtype.name}
738
739
740@tf_export(v1=["initializers.variance_scaling", "variance_scaling_initializer"])
741@deprecation.deprecated_endpoints("initializers.variance_scaling",
742                                  "variance_scaling_initializer")
743class VarianceScaling(Initializer):
744  """Initializer capable of adapting its scale to the shape of weights tensors.
745
746  @compatibility(TF2)
747  Although it is a legacy `compat.v1` API, this symbol is compatible with eager
748  execution and `tf.function`.
749
750  To switch to TF2 APIs, move to using either
751  `tf.initializers.variance_scaling` or `tf.keras.initializers.VarianceScaling`
752  (neither from `compat.v1`) and
753  pass the dtype when calling the initializer.
754
755  #### Structural Mapping to TF2
756
757  Before:
758
759  ```python
760  initializer = tf.compat.v1.variance_scaling_initializer(
761    scale=scale,
762    mode=mode,
763    distribution=distribution
764    seed=seed,
765    dtype=dtype)
766
767  weight_one = tf.Variable(initializer(shape_one))
768  weight_two = tf.Variable(initializer(shape_two))
769  ```
770
771  After:
772
773  ```python
774  initializer = tf.keras.initializers.VarianceScaling(
775    scale=scale,
776    mode=mode,
777    distribution=distribution
778    seed=seed)
779
780  weight_one = tf.Variable(initializer(shape_one, dtype=dtype))
781  weight_two = tf.Variable(initializer(shape_two, dtype=dtype))
782  ```
783
784  #### How to Map Arguments
785
786  | TF1 Arg Name       | TF2 Arg Name    | Note                       |
787  | :----------------- | :-------------- | :------------------------- |
788  | `scale`            | `scale`        | No change to defaults       |
789  | `mode`             | `mode`         | No change to defaults       |
790  | `distribution`     | `distribution` | No change to defaults.      |
791  :                    :                : 'normal' maps to 'truncated_normal' :
792  | `seed`             | `seed`         | |
793  | `dtype`        |  `dtype` | The TF2 api only takes it  |
794  :                :          : as a `__call__` arg, not a constructor arg. :
795  | `partition_info`     | - |  (`__call__` arg in TF1) Not supported       |
796
797  @end_compatibility
798
799  With `distribution="truncated_normal" or "untruncated_normal"`,
800  samples are drawn from a truncated/untruncated normal
801  distribution with a mean of zero and a standard deviation (after truncation,
802  if used) `stddev = sqrt(scale / n)`
803  where n is:
804    - number of input units in the weight tensor, if mode = "fan_in"
805    - number of output units, if mode = "fan_out"
806    - average of the numbers of input and output units, if mode = "fan_avg"
807
808  With `distribution="uniform"`, samples are drawn from a uniform distribution
809  within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
810
811  Args:
812    scale: Scaling factor (positive float).
813    mode: One of "fan_in", "fan_out", "fan_avg".
814    distribution: Random distribution to use. One of "normal", "uniform".
815    seed: A Python integer. Used to create random seeds. See
816      `tf.compat.v1.set_random_seed` for behavior.
817    dtype: Default data type, used if no `dtype` argument is provided when
818      calling the initializer. Only floating point types are supported.
819
820  Raises:
821    ValueError: In case of an invalid value for the "scale", mode" or
822      "distribution" arguments.
823  """
824
825  @deprecated_args(None,
826                   "Call initializer instance with the dtype argument instead "
827                   "of passing it to the constructor", "dtype")
828  @deprecated_arg_values(
829      None,
830      "`normal` is a deprecated alias for `truncated_normal`",
831      distribution="normal")
832  def __init__(self,
833               scale=1.0,
834               mode="fan_in",
835               distribution="truncated_normal",
836               seed=None,
837               dtype=dtypes.float32):
838    if scale <= 0.:
839      raise ValueError("Argument `scale` must be a positive float. Received: "
840                       f"{scale}")
841    if mode not in {"fan_in", "fan_out", "fan_avg"}:
842      raise ValueError("Argument `mode` should be one of ('fan_in', 'fan_out', "
843                       f"'fan_avg'). Received: {mode}")
844    distribution = distribution.lower()
845    if distribution not in {
846        "normal", "uniform", "truncated_normal", "untruncated_normal"
847    }:
848      raise ValueError("Argument `distribution` should be one of ('normal', "
849                       "uniform', 'truncated_normal', 'untruncated_normal'). "
850                       f"Received: {distribution}")
851    self.scale = scale
852    self.mode = mode
853    self.distribution = distribution
854    self.seed = seed
855    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
856
857  def __call__(self, shape, dtype=None, partition_info=None):
858    if dtype is None:
859      dtype = self.dtype
860    scale = self.scale
861    scale_shape = shape
862    if partition_info is not None:
863      scale_shape = partition_info.full_shape
864    fan_in, fan_out = _compute_fans(scale_shape)
865    if self.mode == "fan_in":
866      scale /= max(1., fan_in)
867    elif self.mode == "fan_out":
868      scale /= max(1., fan_out)
869    else:
870      scale /= max(1., (fan_in + fan_out) / 2.)
871    if self.distribution == "normal" or self.distribution == "truncated_normal":
872      # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
873      stddev = math.sqrt(scale) / .87962566103423978
874      return random_ops.truncated_normal(
875          shape, 0.0, stddev, dtype, seed=self.seed)
876    elif self.distribution == "untruncated_normal":
877      stddev = math.sqrt(scale)
878      return random_ops.random_normal(shape, 0.0, stddev, dtype, seed=self.seed)
879    else:
880      limit = math.sqrt(3.0 * scale)
881      return random_ops.random_uniform(
882          shape, -limit, limit, dtype, seed=self.seed)
883
884  def get_config(self):
885    return {
886        "scale": self.scale,
887        "mode": self.mode,
888        "distribution": self.distribution,
889        "seed": self.seed,
890        "dtype": self.dtype.name
891    }
892
893
894@tf_export(v1=["initializers.orthogonal", "orthogonal_initializer"])
895@deprecation.deprecated_endpoints("initializers.orthogonal",
896                                  "orthogonal_initializer")
897class Orthogonal(Initializer):
898  """Initializer that generates an orthogonal matrix.
899
900  If the shape of the tensor to initialize is two-dimensional, it is initialized
901  with an orthogonal matrix obtained from the QR decomposition of a matrix of
902  random numbers drawn from a normal distribution.
903  If the matrix has fewer rows than columns then the output will have orthogonal
904  rows. Otherwise, the output will have orthogonal columns.
905
906  If the shape of the tensor to initialize is more than two-dimensional,
907  a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])`
908  is initialized, where `n` is the length of the shape vector.
909  The matrix is subsequently reshaped to give a tensor of the desired shape.
910
911  Args:
912    gain: multiplicative factor to apply to the orthogonal matrix
913    seed: A Python integer. Used to create random seeds. See
914      `tf.compat.v1.set_random_seed` for behavior.
915    dtype: Default data type, used if no `dtype` argument is provided when
916      calling the initializer. Only floating point types are supported.
917  References:
918      [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C)
919      ([pdf](https://arxiv.org/pdf/1312.6120.pdf))
920  """
921
922  @deprecated_args(None,
923                   "Call initializer instance with the dtype argument instead "
924                   "of passing it to the constructor", "dtype")
925  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
926    self.gain = gain
927    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
928    self.seed = seed
929
930  def __call__(self, shape, dtype=None, partition_info=None):
931    if dtype is None:
932      dtype = self.dtype
933    # Check the shape
934    if len(shape) < 2:
935      raise ValueError("The tensor to initialize, specified by argument `shape`"
936                       " must be at least two-dimensional. Received shape="
937                       f"{shape}")
938    # Flatten the input shape with the last dimension remaining
939    # its original shape so it works for conv2d
940    num_rows = 1
941    for dim in shape[:-1]:
942      num_rows *= dim
943    num_rows = int(num_rows)
944    num_cols = int(shape[-1])
945    if num_rows < num_cols:
946      flat_shape = (num_cols, num_rows)
947    else:
948      flat_shape = (num_rows, num_cols)
949
950    # Generate a random matrix
951    a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed)
952    # Compute the qr factorization
953    q, r = gen_linalg_ops.qr(a, full_matrices=False)
954    # Make Q uniform
955    d = array_ops.diag_part(r)
956    q *= math_ops.sign(d)
957    if num_rows < num_cols:
958      q = array_ops.matrix_transpose(q)
959    return self.gain * array_ops.reshape(q, shape)
960
961  def get_config(self):
962    return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}
963
964
965# Note these haven't been ported to TF2.0. They are not currently visible and
966# the tests are non trivial to port
967class ConvolutionDeltaOrthogonal(Initializer):
968  """Initializer that generates a delta orthogonal kernel for ConvNets.
969
970  The shape of the tensor must have length 3, 4 or 5. The number of input
971  filters must not exceed the number of output filters. The center pixels of the
972  tensor form an orthogonal matrix. Other pixels are set to be zero. See
973  algorithm 2 in (Xiao et al., 2018).
974
975
976  Args:
977    gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
978      The 2-norm of an input is multiplied by a factor of `gain` after applying
979      this convolution.
980    seed: A Python integer. Used to create random seeds. See
981      `tf.compat.v1.set_random_seed` for behavior.
982    dtype: Default data type, used if no `dtype` argument is provided when
983      calling the initializer. Only floating point types are supported.
984  References:
985      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
986      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
987  """
988
989  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
990    self.gain = gain
991    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
992    self.seed = seed
993
994  def __call__(self, shape, dtype=None, partition_info=None):
995    if dtype is None:
996      dtype = self.dtype
997    # Check the shape
998    if len(shape) < 3 or len(shape) > 5:
999      raise ValueError("The tensor to initialize, specified by argument `shape`"
1000                       " must be at least three-dimensional and at most "
1001                       f"five-dimensional. Received shape={shape}")
1002
1003    if shape[-2] > shape[-1]:
1004      raise ValueError(f"In_filters, specified by shape[-2]={shape[-2]} cannot "
1005                       "be greater than out_filters, specified by "
1006                       f"shape[-1]={shape[-1]}.")
1007
1008    # Generate a random matrix
1009    a = random_ops.random_normal([shape[-1], shape[-1]],
1010                                 dtype=dtype,
1011                                 seed=self.seed)
1012    # Compute the qr factorization
1013    q, r = gen_linalg_ops.qr(a, full_matrices=False)
1014    # Make Q uniform
1015    d = array_ops.diag_part(r)
1016    q *= math_ops.sign(d)
1017    q = q[:shape[-2], :]
1018    q *= math_ops.cast(self.gain, dtype=dtype)
1019    if len(shape) == 3:
1020      weight = array_ops.scatter_nd([[(shape[0] - 1) // 2]],
1021                                    array_ops.expand_dims(q, 0), shape)
1022    elif len(shape) == 4:
1023      weight = array_ops.scatter_nd([[(shape[0] - 1) // 2,
1024                                      (shape[1] - 1) // 2]],
1025                                    array_ops.expand_dims(q, 0), shape)
1026    else:
1027      weight = array_ops.scatter_nd([[(shape[0] - 1) // 2, (shape[1] - 1) // 2,
1028                                      (shape[2] - 1) // 2]],
1029                                    array_ops.expand_dims(q, 0), shape)
1030    return weight
1031
1032  def get_config(self):
1033    return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}
1034
1035
1036class ConvolutionOrthogonal(Initializer):
1037  """Initializer that generates orthogonal kernel for ConvNets.
1038
1039  Base class used to construct 1D, 2D and 3D orthogonal kernels for convolution.
1040
1041  Args:
1042    gain: multiplicative factor to apply to the orthogonal matrix. Default is 1.
1043      The 2-norm of an input is multiplied by a factor of `gain` after applying
1044      this convolution.
1045    seed: A Python integer. Used to create random seeds. See
1046      `tf.compat.v1.set_random_seed` for behavior.
1047    dtype: Default data type, used if no `dtype` argument is provided when
1048      calling the initializer. Only floating point types are supported.
1049  References:
1050      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
1051      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
1052  """
1053
1054  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
1055    self.gain = gain
1056    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
1057    self.seed = seed
1058
1059  def __call__(self, shape, dtype=None, partition_info=None):
1060    raise NotImplementedError
1061
1062  def get_config(self):
1063    return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}
1064
1065  # Helper functions.
1066  def _orthogonal_matrix(self, n):
1067    """Construct an n x n orthogonal matrix.
1068
1069    Args:
1070      n: Dimension.
1071
1072    Returns:
1073      A n x n orthogonal matrix.
1074    """
1075    a = random_ops.random_normal([n, n], dtype=self.dtype, seed=self.seed)
1076    if self.seed:
1077      self.seed += 1
1078    q, r = gen_linalg_ops.qr(a)
1079    d = array_ops.diag_part(r)
1080    # make q uniform
1081    q *= math_ops.sign(d)
1082    return q
1083
1084  def _symmetric_projection(self, n):
1085    """Compute a n x n symmetric projection matrix.
1086
1087    Args:
1088      n: Dimension.
1089
1090    Returns:
1091      A n x n symmetric projection matrix, i.e. a matrix P s.t. P=P*P, P=P^T.
1092    """
1093    q = self._orthogonal_matrix(n)
1094    # randomly zeroing out some columns
1095    mask = math_ops.cast(
1096        random_ops.random_normal([n], seed=self.seed) > 0, self.dtype)
1097    if self.seed:
1098      self.seed += 1
1099    c = math_ops.multiply(q, mask)
1100    return math_ops.matmul(c, array_ops.matrix_transpose(c))
1101
1102
1103class ConvolutionOrthogonal2D(ConvolutionOrthogonal):
1104  """Initializer that generates a 2D orthogonal kernel for ConvNets.
1105
1106  The shape of the tensor must have length 4. The number of input
1107  filters must not exceed the number of output filters.
1108  The orthogonality(==isometry) is exact when the inputs are circular padded.
1109  There are finite-width effects with non-circular padding (e.g. zero padding).
1110  See algorithm 1 in (Xiao et al., 2018).
1111
1112  Args:
1113    gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
1114      This has the effect of scaling the output 2-norm by a factor of `gain`.
1115    seed: A Python integer. Used to create random seeds. See
1116      `tf.compat.v1.set_random_seed` for behavior.
1117    dtype: Default data type, used if no `dtype` argument is provided when
1118      calling the initializer. Only floating point types are supported.
1119  References:
1120      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
1121      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
1122  """
1123
1124  def __call__(self, shape, dtype=None, partition_info=None):
1125    if dtype is None:
1126      dtype = self.dtype
1127    if len(shape) != 4:
1128      raise ValueError("The tensor to initialize, specified by argument `shape`"
1129                       f" must be four-dimensional. Received: {shape}")
1130
1131    if shape[-2] > shape[-1]:
1132      raise ValueError(f"In_filters, specified by shape[-2]={shape[-2]} cannot "
1133                       "be greater than out_filters, specified by "
1134                       f"shape[-1]={shape[-1]}.")
1135
1136    if shape[0] != shape[1]:
1137      raise ValueError(f"Kernel sizes, specified by shape[0]={shape[0]} and "
1138                       f"shape[1]={shape[1]} must be equal.")
1139
1140    kernel = self._orthogonal_kernel(shape[0], shape[2], shape[3])
1141    kernel *= math_ops.cast(self.gain, dtype=dtype)
1142    return kernel
1143
1144  def _dict_to_tensor(self, x, k1, k2):
1145    """Convert a dictionary to a tensor.
1146
1147    Args:
1148      x: A k1 * k2 dictionary.
1149      k1: First dimension of x.
1150      k2: Second dimension of x.
1151
1152    Returns:
1153      A k1 * k2 tensor.
1154    """
1155
1156    return array_ops.stack([array_ops.stack([x[i, j] for j in range(k2)])
1157                            for i in range(k1)])
1158
1159  def _block_orth(self, p1, p2):
1160    """Construct a 2 x 2 kernel.
1161
1162    Used to construct orthgonal kernel.
1163
1164    Args:
1165      p1: A symmetric projection matrix.
1166      p2: A symmetric projection matrix.
1167
1168    Returns:
1169      A 2 x 2 kernel [[p1p2,         p1(1-p2)],
1170                      [(1-p1)p2, (1-p1)(1-p2)]].
1171    Raises:
1172      ValueError: If the dimensions of p1 and p2 are different.
1173    """
1174    if p1.shape.as_list() != p2.shape.as_list():
1175      raise ValueError("The dimension of the matrices must be the same. "
1176                       f"Received p1.shape={p1.shape} and p2.shape={p2.shape}.")
1177    n = p1.shape.as_list()[0]
1178    kernel2x2 = {}
1179    eye = linalg_ops_impl.eye(n, dtype=self.dtype)
1180    kernel2x2[0, 0] = math_ops.matmul(p1, p2)
1181    kernel2x2[0, 1] = math_ops.matmul(p1, (eye - p2))
1182    kernel2x2[1, 0] = math_ops.matmul((eye - p1), p2)
1183    kernel2x2[1, 1] = math_ops.matmul((eye - p1), (eye - p2))
1184
1185    return kernel2x2
1186
1187  def _matrix_conv(self, m1, m2):
1188    """Matrix convolution.
1189
1190    Args:
1191      m1: A k x k dictionary, each element is a n x n matrix.
1192      m2: A l x l dictionary, each element is a n x n matrix.
1193
1194    Returns:
1195      (k + l - 1) * (k + l - 1) dictionary each element is a n x n matrix.
1196    Raises:
1197      ValueError: if the entries of m1 and m2 are of different dimensions.
1198    """
1199
1200    n = (m1[0, 0]).shape.as_list()[0]
1201    if n != (m2[0, 0]).shape.as_list()[0]:
1202      raise ValueError("The entries in matrices m1 and m2 must have the same "
1203                       f"dimensions. Received m1[0, 0].shape={m1[0, 0].shape} "
1204                       f"and m2[0, 0].shape={m2[0, 0].shape}.")
1205    k = int(np.sqrt(len(m1)))
1206    l = int(np.sqrt(len(m2)))
1207    result = {}
1208    size = k + l - 1
1209    # Compute matrix convolution between m1 and m2.
1210    for i in range(size):
1211      for j in range(size):
1212        result[i, j] = array_ops.zeros([n, n], self.dtype)
1213        for index1 in range(min(k, i + 1)):
1214          for index2 in range(min(k, j + 1)):
1215            if (i - index1) < l and (j - index2) < l:
1216              result[i, j] += math_ops.matmul(m1[index1, index2],
1217                                              m2[i - index1, j - index2])
1218    return result
1219
1220  def _orthogonal_kernel(self, ksize, cin, cout):
1221    """Construct orthogonal kernel for convolution.
1222
1223    Args:
1224      ksize: Kernel size.
1225      cin: Number of input channels.
1226      cout: Number of output channels.
1227
1228    Returns:
1229      An [ksize, ksize, cin, cout] orthogonal kernel.
1230    Raises:
1231      ValueError: If cin > cout.
1232    """
1233    if cin > cout:
1234      raise ValueError(f"The number of input channels (cin={cin}) cannot exceed"
1235                       f" the number of output channels (cout={cout}).")
1236    orth = self._orthogonal_matrix(cout)[0:cin, :]
1237    if ksize == 1:
1238      return array_ops.expand_dims(array_ops.expand_dims(orth, 0), 0)
1239
1240    p = self._block_orth(
1241        self._symmetric_projection(cout), self._symmetric_projection(cout))
1242    for _ in range(ksize - 2):
1243      temp = self._block_orth(
1244          self._symmetric_projection(cout), self._symmetric_projection(cout))
1245      p = self._matrix_conv(p, temp)
1246    for i in range(ksize):
1247      for j in range(ksize):
1248        p[i, j] = math_ops.matmul(orth, p[i, j])
1249
1250    return self._dict_to_tensor(p, ksize, ksize)
1251
1252
1253class ConvolutionOrthogonal1D(ConvolutionOrthogonal):
1254  """Initializer that generates a 1D orthogonal kernel for ConvNets.
1255
1256  The shape of the tensor must have length 3. The number of input
1257  filters must not exceed the number of output filters.
1258  The orthogonality(==isometry) is exact when the inputs are circular padded.
1259  There are finite-width effects with non-circular padding (e.g. zero padding).
1260  See algorithm 1 in (Xiao et al., 2018).
1261
1262  Args:
1263    gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
1264      The 2-norm of an input is multiplied by a factor of `gain` after applying
1265      this convolution.
1266    seed: A Python integer. Used to create random seeds. See
1267      `tf.compat.v1.set_random_seed` for behavior.
1268    dtype: Default data type, used if no `dtype` argument is provided when
1269      calling the initializer. Only floating point types are supported.
1270  References:
1271      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
1272      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
1273  """
1274
1275  def __call__(self, shape, dtype=None, partition_info=None):
1276    if dtype is None:
1277      dtype = self.dtype
1278    if len(shape) != 3:
1279      raise ValueError("The tensor to initialize, specified by argument `shape`"
1280                       f" must be three-dimensional. Received shape={shape}")
1281
1282    if shape[-2] > shape[-1]:
1283      raise ValueError(f"In_filters, specified by shape[-2]={shape[-2]} cannot "
1284                       "be greater than out_filters, specified by "
1285                       f"shape[-1]={shape[-1]}.")
1286
1287    kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1])
1288    kernel *= math_ops.cast(self.gain, dtype=dtype)
1289    return kernel
1290
1291  def _dict_to_tensor(self, x, k):
1292    """Convert a dictionary to a tensor.
1293
1294    Args:
1295      x: A dictionary of length k.
1296      k: Dimension of x.
1297
1298    Returns:
1299      A tensor with the same dimension.
1300    """
1301
1302    return array_ops.stack([x[i] for i in range(k)])
1303
1304  def _block_orth(self, projection_matrix):
1305    """Construct a kernel.
1306
1307    Used to construct orthgonal kernel.
1308
1309    Args:
1310      projection_matrix: A symmetric projection matrix of size n x n.
1311
1312    Returns:
1313      [projection_matrix, (1 - projection_matrix)].
1314    """
1315    n = projection_matrix.shape.as_list()[0]
1316    kernel = {}
1317    eye = linalg_ops_impl.eye(n, dtype=self.dtype)
1318    kernel[0] = projection_matrix
1319    kernel[1] = eye - projection_matrix
1320    return kernel
1321
1322  def _matrix_conv(self, m1, m2):
1323    """Matrix convolution.
1324
1325    Args:
1326      m1: A dictionary of length k, each element is a n x n matrix.
1327      m2: A dictionary of length l, each element is a n x n matrix.
1328
1329    Returns:
1330      (k + l - 1)  dictionary each element is a n x n matrix.
1331    Raises:
1332      ValueError: Ff the entries of m1 and m2 are of different dimensions.
1333    """
1334
1335    n = (m1[0]).shape.as_list()[0]
1336    if n != (m2[0]).shape.as_list()[0]:
1337      raise ValueError("The entries in matrices m1 and m2 must have the same "
1338                       f"dimensions. Received m1[0].shape={m1[0].shape} "
1339                       f"and m2[0].shape={m2[0].shape}.")
1340    k = len(m1)
1341    l = len(m2)
1342    result = {}
1343    size = k + l - 1
1344    # Compute matrix convolution between m1 and m2.
1345    for i in range(size):
1346      result[i] = array_ops.zeros([n, n], self.dtype)
1347      for index in range(min(k, i + 1)):
1348        if (i - index) < l:
1349          result[i] += math_ops.matmul(m1[index], m2[i - index])
1350    return result
1351
1352  def _orthogonal_kernel(self, ksize, cin, cout):
1353    """Construct orthogonal kernel for convolution.
1354
1355    Args:
1356      ksize: Kernel size.
1357      cin: Number of input channels.
1358      cout: Number of output channels.
1359
1360    Returns:
1361      An [ksize, ksize, cin, cout] orthogonal kernel.
1362    Raises:
1363      ValueError: If cin > cout.
1364    """
1365    if cin > cout:
1366      raise ValueError(f"The number of input channels (cin={cin}) cannot exceed"
1367                       f" the number of output channels (cout={cout}).")
1368    orth = self._orthogonal_matrix(cout)[0:cin, :]
1369    if ksize == 1:
1370      return array_ops.expand_dims(orth, 0)
1371
1372    p = self._block_orth(self._symmetric_projection(cout))
1373    for _ in range(ksize - 2):
1374      temp = self._block_orth(self._symmetric_projection(cout))
1375      p = self._matrix_conv(p, temp)
1376    for i in range(ksize):
1377      p[i] = math_ops.matmul(orth, p[i])
1378
1379    return self._dict_to_tensor(p, ksize)
1380
1381
1382class ConvolutionOrthogonal3D(ConvolutionOrthogonal):
1383  """Initializer that generates a 3D orthogonal kernel for ConvNets.
1384
1385  The shape of the tensor must have length 5. The number of input
1386  filters must not exceed the number of output filters.
1387  The orthogonality(==isometry) is exact when the inputs are circular padded.
1388  There are finite-width effects with non-circular padding (e.g. zero padding).
1389  See algorithm 1 (Xiao et al., 2018).
1390
1391  Args:
1392    gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
1393      The 2-norm of an input is multiplied by a factor of `gain` after applying
1394      this convolution.
1395    seed: A Python integer. Used to create random seeds. See
1396      `tf.compat.v1.set_random_seed` for behavior.
1397    dtype: Default data type, used if no `dtype` argument is provided when
1398      calling the initializer. Only floating point types are supported.
1399  References:
1400      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
1401      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
1402  """
1403
1404  def __call__(self, shape, dtype=None, partition_info=None):
1405    if dtype is None:
1406      dtype = self.dtype
1407    if len(shape) != 5:
1408      raise ValueError("The tensor to initialize, specified by argument `shape`"
1409                       f" must be five-dimensional. Received shape={shape}")
1410
1411    if shape[-2] > shape[-1]:
1412      raise ValueError(f"In_filters, specified by shape[-2]={shape[-2]} cannot "
1413                       "be greater than out_filters, specified by "
1414                       f"shape[-1]={shape[-1]}.")
1415
1416    if shape[0] != shape[1] or shape[0] != shape[2]:
1417      raise ValueError(f"Kernel sizes, specified by shape[0]={shape[0]},  "
1418                       f"shape[1]={shape[1]} and shape[2]={shape[2]} must be "
1419                       "equal.")
1420
1421    kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1])
1422    kernel *= math_ops.cast(self.gain, dtype=dtype)
1423    return kernel
1424
1425  def _dict_to_tensor(self, x, k1, k2, k3):
1426    """Convert a dictionary to a tensor.
1427
1428    Args:
1429      x: A k1 * k2 dictionary.
1430      k1: First dimension of x.
1431      k2: Second dimension of x.
1432      k3: Third dimension of x.
1433
1434    Returns:
1435      A k1 * k2 * k3 tensor.
1436    """
1437
1438    return array_ops.stack([array_ops.stack(
1439        [array_ops.stack([x[i, j, k] for k in range(k3)])
1440         for j in range(k2)]) for i in range(k1)])
1441
1442  def _block_orth(self, p1, p2, p3):
1443    """Construct a 3 x 3 kernel.
1444
1445    Used to construct orthgonal kernel.
1446
1447    Args:
1448      p1: A symmetric projection matrix.
1449      p2: A symmetric projection matrix.
1450      p3: A symmetric projection matrix.
1451
1452    Returns:
1453      A 2 x 2 x 2 kernel.
1454    Raises:
1455      ValueError: If the dimensions of p1, p2 and p3 are different.
1456    """
1457    p1_shape = p1.shape.as_list()
1458    if p1_shape != p2.shape.as_list() or p1_shape != p3.shape.as_list():
1459      raise ValueError("The dimension of the matrices must be the same. "
1460                       f"Received p1.shape={p1.shape}, p2.shape={p2.shape} and"
1461                       f" p3.shape={p3.shape}.")
1462    n = p1_shape[0]
1463    eye = linalg_ops_impl.eye(n, dtype=self.dtype)
1464    kernel2x2x2 = {}
1465
1466    def matmul(p1, p2, p3):
1467      return math_ops.matmul(math_ops.matmul(p1, p2), p3)
1468
1469    def cast(i, p):
1470      """Return p or (1-p)."""
1471      return i * p + (1 - i) * (eye - p)
1472
1473    for i in [0, 1]:
1474      for j in [0, 1]:
1475        for k in [0, 1]:
1476          kernel2x2x2[i, j, k] = matmul(cast(i, p1), cast(j, p2), cast(k, p3))
1477    return kernel2x2x2
1478
1479  def _matrix_conv(self, m1, m2):
1480    """Matrix convolution.
1481
1482    Args:
1483      m1: is a k x k x k  dictionary, each element is a n x n matrix.
1484      m2: is a l x l x l dictionary, each element is a n x n matrix.
1485
1486    Returns:
1487      (k + l - 1) x (k + l - 1) x (k + l - 1) dictionary each
1488      element is a n x n matrix.
1489    Raises:
1490      ValueError: if the entries of m1 and m2 are of different dimensions.
1491    """
1492
1493    n = (m1[0, 0, 0]).shape.as_list()[0]
1494    if n != (m2[0, 0, 0]).shape.as_list()[0]:
1495      raise ValueError("The entries in matrices m1 and m2 must have the same "
1496                       "dimensions. Received m1[0, 0, 0].shape="
1497                       f"{m1[0, 0, 0].shape} and m2[0, 0, 0].shape="
1498                       f"{m2[0, 0, 0].shape}.")
1499    k = int(np.cbrt(len(m1)))
1500    l = int(np.cbrt(len(m2)))
1501    result = {}
1502    size = k + l - 1
1503    # Compute matrix convolution between m1 and m2.
1504    for i in range(size):
1505      for j in range(size):
1506        for r in range(size):
1507          result[i, j, r] = array_ops.zeros([n, n], self.dtype)
1508          for index1 in range(min(k, i + 1)):
1509            for index2 in range(min(k, j + 1)):
1510              for index3 in range(min(k, r + 1)):
1511                if (i - index1) < l and (j - index2) < l and (r - index3) < l:
1512                  result[i, j, r] += math_ops.matmul(
1513                      m1[index1, index2, index3],
1514                      m2[i - index1, j - index2, r - index3])
1515    return result
1516
1517  def _orthogonal_kernel(self, ksize, cin, cout):
1518    """Construct orthogonal kernel for convolution.
1519
1520    Args:
1521      ksize: Kernel size.
1522      cin: Number of input channels.
1523      cout: Number of output channels.
1524
1525    Returns:
1526      An [ksize, ksize, ksize, cin, cout] orthogonal kernel.
1527    Raises:
1528      ValueError: If cin > cout.
1529    """
1530    if cin > cout:
1531      raise ValueError(f"The number of input channels (cin={cin}) cannot exceed"
1532                       f" the number of output channels (cout={cout}).")
1533    orth = self._orthogonal_matrix(cout)[0:cin, :]
1534    if ksize == 1:
1535      return array_ops.expand_dims(
1536          array_ops.expand_dims(array_ops.expand_dims(orth, 0), 0), 0)
1537
1538    p = self._block_orth(
1539        self._symmetric_projection(cout), self._symmetric_projection(cout),
1540        self._symmetric_projection(cout))
1541    for _ in range(ksize - 2):
1542      temp = self._block_orth(
1543          self._symmetric_projection(cout), self._symmetric_projection(cout),
1544          self._symmetric_projection(cout))
1545      p = self._matrix_conv(p, temp)
1546    for i in range(ksize):
1547      for j in range(ksize):
1548        for k in range(ksize):
1549          p[i, j, k] = math_ops.matmul(orth, p[i, j, k])
1550
1551    return self._dict_to_tensor(p, ksize, ksize, ksize)
1552
1553
1554@tf_export(v1=["initializers.identity"])
1555@deprecation.deprecated_endpoints("initializers.identity")
1556class Identity(Initializer):
1557  """Initializer that generates the identity matrix.
1558
1559  Only use for 2D matrices.
1560
1561  Args:
1562    gain: Multiplicative factor to apply to the identity matrix.
1563    dtype: Default data type, used if no `dtype` argument is provided when
1564      calling the initializer. Only floating point types are supported.
1565  """
1566
1567  @deprecated_args(None,
1568                   "Call initializer instance with the dtype argument instead "
1569                   "of passing it to the constructor", "dtype")
1570  def __init__(self, gain=1.0, dtype=dtypes.float32):
1571    self.gain = gain
1572    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
1573
1574  def __call__(self, shape, dtype=None, partition_info=None):
1575    full_shape = shape if partition_info is None else partition_info.full_shape
1576    if len(full_shape) != 2:
1577      raise ValueError("The tensor to initialize, specified by argument `shape`"
1578                       " must be at least two-dimensional. Received shape="
1579                       f"{shape}")
1580    if dtype is None:
1581      dtype = self.dtype
1582    if isinstance(full_shape, tensor_shape.TensorShape):
1583      full_shape = full_shape.as_list()
1584    initializer = linalg_ops_impl.eye(*full_shape, dtype=dtype)
1585    if partition_info is not None:
1586      initializer = array_ops.slice(initializer, partition_info.var_offset,
1587                                    shape)
1588    return self.gain * initializer
1589
1590  def get_config(self):
1591    return {"gain": self.gain, "dtype": self.dtype.name}
1592
1593
1594@tf_export(v1=["glorot_uniform_initializer", "initializers.glorot_uniform"])
1595@deprecation.deprecated_endpoints("glorot_uniform_initializer",
1596                                  "initializers.glorot_uniform")
1597class GlorotUniform(VarianceScaling):
1598  """The Glorot uniform initializer, also called Xavier uniform initializer.
1599
1600  It draws samples from a uniform distribution within [-limit, limit]
1601  where `limit` is `sqrt(6 / (fan_in + fan_out))`
1602  where `fan_in` is the number of input units in the weight tensor
1603  and `fan_out` is the number of output units in the weight tensor.
1604
1605  Args:
1606    seed: A Python integer. Used to create random seeds. See
1607      `tf.compat.v1.set_random_seed` for behavior.
1608    dtype: Default data type, used if no `dtype` argument is provided when
1609      calling the initializer. Only floating point types are supported.
1610  References:
1611      [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
1612      ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
1613  """
1614
1615  @deprecated_args(None,
1616                   "Call initializer instance with the dtype argument instead "
1617                   "of passing it to the constructor", "dtype")
1618  def __init__(self, seed=None, dtype=dtypes.float32):
1619    super(GlorotUniform, self).__init__(
1620        scale=1.0, mode="fan_avg", distribution="uniform", seed=seed)
1621
1622  def get_config(self):
1623    return {"seed": self.seed, "dtype": self.dtype.name}
1624
1625
1626@tf_export(v1=["glorot_normal_initializer", "initializers.glorot_normal"])
1627@deprecation.deprecated_endpoints("glorot_normal_initializer",
1628                                  "initializers.glorot_normal")
1629class GlorotNormal(VarianceScaling):
1630  """The Glorot normal initializer, also called Xavier normal initializer.
1631
1632  It draws samples from a truncated normal distribution centered on 0
1633  with standard deviation (after truncation) given by
1634  `stddev = sqrt(2 / (fan_in + fan_out))` where `fan_in` is the number
1635  of input units in the weight tensor and `fan_out` is the number of
1636  output units in the weight tensor.
1637
1638  Args:
1639    seed: A Python integer. Used to create random seeds. See
1640      `tf.compat.v1.set_random_seed` for behavior.
1641    dtype: Default data type, used if no `dtype` argument is provided when
1642      calling the initializer. Only floating point types are supported.
1643  References:
1644      [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
1645      ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
1646  """
1647
1648  @deprecated_args(None,
1649                   "Call initializer instance with the dtype argument instead "
1650                   "of passing it to the constructor", "dtype")
1651  def __init__(self, seed=None, dtype=dtypes.float32):
1652    super(GlorotNormal, self).__init__(
1653        scale=1.0, mode="fan_avg", distribution="truncated_normal", seed=seed)
1654
1655  def get_config(self):
1656    return {"seed": self.seed, "dtype": self.dtype.name}
1657
1658
1659# Aliases.
1660
1661# pylint: disable=invalid-name
1662zeros_initializer = Zeros
1663ones_initializer = Ones
1664constant_initializer = Constant
1665random_uniform_initializer = RandomUniform
1666random_normal_initializer = RandomNormal
1667truncated_normal_initializer = TruncatedNormal
1668uniform_unit_scaling_initializer = UniformUnitScaling
1669variance_scaling_initializer = VarianceScaling
1670glorot_uniform_initializer = GlorotUniform
1671glorot_normal_initializer = GlorotNormal
1672orthogonal_initializer = Orthogonal
1673identity_initializer = Identity
1674convolutional_delta_orthogonal = ConvolutionDeltaOrthogonal
1675convolutional_orthogonal_1d = ConvolutionOrthogonal1D
1676convolutional_orthogonal_2d = ConvolutionOrthogonal2D
1677convolutional_orthogonal_3d = ConvolutionOrthogonal3D
1678# pylint: enable=invalid-name
1679
1680
1681@tf_export(v1=["initializers.lecun_normal"])
1682def lecun_normal(seed=None):
1683  """LeCun normal initializer.
1684
1685  It draws samples from a truncated normal distribution centered on 0
1686  with standard deviation (after truncation) given by
1687  `stddev = sqrt(1 / fan_in)` where `fan_in` is the number of
1688  input units in the weight tensor.
1689
1690  Args:
1691      seed: A Python integer. Used to seed the random generator.
1692
1693  Returns:
1694      An initializer.
1695
1696  References:
1697      - Self-Normalizing Neural Networks,
1698      [Klambauer et al.,
1699      2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks)
1700      # pylint: disable=line-too-long
1701      ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
1702      - Efficient Backprop,
1703      [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
1704  """
1705  return VarianceScaling(
1706      scale=1., mode="fan_in", distribution="truncated_normal", seed=seed)
1707
1708
1709@tf_export(v1=["initializers.lecun_uniform"])
1710def lecun_uniform(seed=None):
1711  """LeCun uniform initializer.
1712
1713  It draws samples from a uniform distribution within [-limit, limit]
1714  where `limit` is `sqrt(3 / fan_in)`
1715  where `fan_in` is the number of input units in the weight tensor.
1716
1717  Args:
1718      seed: A Python integer. Used to seed the random generator.
1719
1720  Returns:
1721      An initializer.
1722
1723  References:
1724      - Self-Normalizing Neural Networks,
1725      [Klambauer et al.,
1726      2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks)
1727      # pylint: disable=line-too-long
1728      ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
1729      - Efficient Backprop,
1730      [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
1731  """
1732  return VarianceScaling(
1733      scale=1., mode="fan_in", distribution="uniform", seed=seed)
1734
1735
1736@tf_export(v1=["initializers.he_normal"])
1737def he_normal(seed=None):
1738  """He normal initializer.
1739
1740  It draws samples from a truncated normal distribution centered on 0
1741  with standard deviation (after truncation) given by
1742  `stddev = sqrt(2 / fan_in)` where `fan_in` is the number of
1743  input units in the weight tensor.
1744
1745  Args:
1746      seed: A Python integer. Used to seed the random generator.
1747
1748  Returns:
1749      An initializer.
1750
1751  References:
1752      [He et al., 2015]
1753      (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html)
1754      # pylint: disable=line-too-long
1755      ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
1756  """
1757  return VarianceScaling(
1758      scale=2., mode="fan_in", distribution="truncated_normal", seed=seed)
1759
1760
1761@tf_export(v1=["initializers.he_uniform"])
1762def he_uniform(seed=None):
1763  """He uniform variance scaling initializer.
1764
1765  It draws samples from a uniform distribution within [-limit, limit]
1766  where `limit` is `sqrt(6 / fan_in)`
1767  where `fan_in` is the number of input units in the weight tensor.
1768
1769  Args:
1770      seed: A Python integer. Used to seed the random generator.
1771
1772  Returns:
1773      An initializer.
1774
1775  References:
1776      [He et al., 2015]
1777      (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html)
1778      # pylint: disable=line-too-long
1779      ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
1780  """
1781  return VarianceScaling(
1782      scale=2., mode="fan_in", distribution="uniform", seed=seed)
1783
1784
1785# Utility functions.
1786
1787
1788def _compute_fans(shape):
1789  """Computes the number of input and output units for a weight shape.
1790
1791  Args:
1792    shape: Integer shape tuple or TF tensor shape.
1793
1794  Returns:
1795    A tuple of integer scalars (fan_in, fan_out).
1796  """
1797  if len(shape) < 1:  # Just to avoid errors for constants.
1798    fan_in = fan_out = 1
1799  elif len(shape) == 1:
1800    fan_in = fan_out = shape[0]
1801  elif len(shape) == 2:
1802    fan_in = shape[0]
1803    fan_out = shape[1]
1804  else:
1805    # Assuming convolution kernels (2D, 3D, or more).
1806    # kernel shape: (..., input_depth, depth)
1807    receptive_field_size = 1
1808    for dim in shape[:-2]:
1809      receptive_field_size *= dim
1810    fan_in = shape[-2] * receptive_field_size
1811    fan_out = shape[-1] * receptive_field_size
1812  return int(fan_in), int(fan_out)
1813
1814
1815def _assert_float_dtype(dtype):
1816  """Validate and return floating point type based on `dtype`.
1817
1818  `dtype` must be a floating point type.
1819
1820  Args:
1821    dtype: The data type to validate.
1822
1823  Returns:
1824    Validated type.
1825
1826  Raises:
1827    ValueError: if `dtype` is not a floating point type.
1828  """
1829  if not dtype.is_floating:
1830    raise ValueError("Argument `dtype` is expected to be floating point. "
1831                     f"Received: {dtype}.")
1832  return dtype
1833