xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/init.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/enum.h>
5 #include <torch/types.h>
6 
7 namespace torch {
8 namespace nn {
9 namespace init {
10 
11 using NonlinearityType = std::variant<
12     enumtype::kLinear,
13     enumtype::kConv1D,
14     enumtype::kConv2D,
15     enumtype::kConv3D,
16     enumtype::kConvTranspose1D,
17     enumtype::kConvTranspose2D,
18     enumtype::kConvTranspose3D,
19     enumtype::kSigmoid,
20     enumtype::kTanh,
21     enumtype::kReLU,
22     enumtype::kLeakyReLU>;
23 
24 using FanModeType = std::variant<enumtype::kFanIn, enumtype::kFanOut>;
25 
26 } // namespace init
27 } // namespace nn
28 
29 namespace nn {
30 namespace init {
31 
32 /// Return the recommended gain value for the given nonlinearity function.
33 TORCH_API double calculate_gain(
34     NonlinearityType nonlinearity,
35     double param = 0.01);
36 
37 /// Fills the given `tensor` with the provided `value` in-place, and returns it.
38 /// No gradient will be recorded for this operation.
39 TORCH_API Tensor constant_(Tensor tensor, Scalar value);
40 
41 /// Fills the given `tensor` with the Dirac delta function in-place, and returns
42 /// it. No gradient will be recorded for this operation.
43 TORCH_API Tensor dirac_(Tensor tensor);
44 
45 /// Fills the given 2-dimensional `matrix` with an identity matrix.
46 /// No gradient will be recorded for this operation.
47 TORCH_API Tensor eye_(Tensor matrix);
48 
49 /// Fills the given 2-dimensional `matrix` with values drawn from a normal
50 /// distribution parameterized by `mean` and `std`.
51 /// No gradient will be recorded for this operation.
52 TORCH_API Tensor normal_(Tensor tensor, double mean = 0, double std = 1);
53 
54 /// Fills the given `tensor` with ones.
55 /// No gradient will be recorded for this operation.
56 TORCH_API Tensor ones_(Tensor tensor);
57 
58 /// Fills the input `Tensor` with a (semi) orthogonal matrix, as described in
59 /// "Exact solutions to the nonlinear dynamics of learning in deep linear neural
60 /// networks" - Saxe, A. et al. (2013). The input tensor must have at least 2
61 /// dimensions, and for tensors with more than 2 dimensions the trailing
62 /// dimensions are flattened.
63 /// No gradient will be recorded for this operation.
64 TORCH_API Tensor orthogonal_(Tensor tensor, double gain = 1.0);
65 
66 /// Fills the 2D input `Tensor` as a sparse matrix, where the
67 /// non-zero elements will be drawn from a centered normal distribution
68 /// with the given standard deviation `std`, as described in "Deep learning via
69 /// Hessian-free optimization" - Martens, J. (2010). The `sparsity` is a real
70 /// value between 0 and 1 that controls the fraction of elements in each column
71 /// to be set to zero.
72 /// No gradient will be recorded for this operation.
73 TORCH_API Tensor sparse_(Tensor tensor, double sparsity, double std = 0.01);
74 
75 /// Fills the given 2-dimensional `matrix` with values drawn from a uniform
76 /// distribution parameterized by `low` and `high`.
77 /// No gradient will be recorded for this operation.
78 TORCH_API Tensor uniform_(Tensor tensor, double low = 0, double high = 1);
79 
80 /// Fills the input `Tensor` with values according to the method
81 /// described in "Delving deep into rectifiers: Surpassing human-level
82 /// performance on ImageNet classification" - He, K. et al. (2015), using a
83 /// normal distribution. Also known as He initialization.
84 /// No gradient will be recorded for this operation.
85 TORCH_API Tensor kaiming_normal_(
86     Tensor tensor,
87     double a = 0,
88     FanModeType mode = torch::kFanIn,
89     NonlinearityType nonlinearity = torch::kLeakyReLU);
90 
91 /// Fills the input `Tensor` with values according to the method
92 /// described in "Delving deep into rectifiers: Surpassing human-level
93 /// performance on ImageNet classification" - He, K. et al. (2015), using a
94 /// uniform distribution. Also known as He initialization.
95 /// No gradient will be recorded for this operation.
96 TORCH_API Tensor kaiming_uniform_(
97     Tensor tensor,
98     double a = 0,
99     FanModeType mode = torch::kFanIn,
100     NonlinearityType nonlinearity = torch::kLeakyReLU);
101 
102 /// Fills the input `Tensor` with values according to the method
103 /// described in "Understanding the difficulty of training deep feedforward
104 /// neural networks" - Glorot, X. & Bengio, Y. (2010). Values are scaled by the
105 /// `gain` parameter. No gradient will be recorded for this operation.
106 TORCH_API Tensor xavier_normal_(Tensor tensor, double gain = 1.0);
107 
108 /// Fills the input `Tensor` with values according to the method
109 /// described in "Understanding the difficulty of training deep feedforward
110 /// neural networks" - Glorot, X. & Bengio, Y. (2010), using a uniform
111 /// distribution. Values are scaled by the `gain` parameter
112 /// No gradient will be recorded for this operation.
113 TORCH_API Tensor xavier_uniform_(Tensor tensor, double gain = 1.0);
114 
115 /// Fills the given `tensor` with zeros.
116 /// No gradient will be recorded for this operation.
117 TORCH_API Tensor zeros_(Tensor tensor);
118 
119 TORCH_API std::tuple<int64_t, int64_t> _calculate_fan_in_and_fan_out(
120     const Tensor& tensor);
121 
122 } // namespace init
123 } // namespace nn
124 } // namespace torch
125