// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/init.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/init.h>
2 
3 #include <torch/linalg.h>
4 #include <torch/types.h>
5 #include <torch/utils.h>
6 
7 #include <ATen/ATen.h>
8 #include <c10/util/Exception.h>
9 #include <c10/util/irange.h>
10 
11 #include <algorithm>
12 #include <cmath>
13 #include <cstddef>
14 #include <tuple>
15 
16 namespace torch {
17 namespace nn {
18 namespace init {
19 namespace {
20 struct Fan {
Fantorch::nn::init::__anon275c2cb80111::Fan21   explicit Fan(Tensor& tensor) {
22     const auto dimensions = tensor.ndimension();
23     TORCH_CHECK(
24         dimensions >= 2,
25         "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions");
26 
27     if (dimensions == 2) {
28       in = tensor.size(1);
29       out = tensor.size(0);
30     } else {
31       in = tensor.size(1) * tensor[0][0].numel();
32       out = tensor.size(0) * tensor[0][0].numel();
33     }
34   }
35 
36   int64_t in;
37   int64_t out;
38 };
39 
calculate_kaiming_std(Tensor tensor,double a,FanModeType mode,NonlinearityType nonlinearity)40 double calculate_kaiming_std(
41     Tensor tensor,
42     double a,
43     FanModeType mode,
44     NonlinearityType nonlinearity) {
45   NoGradGuard guard;
46   Fan fan(tensor);
47   const auto gain = calculate_gain(nonlinearity, a);
48   double std = 0.0;
49 
50   if (std::holds_alternative<enumtype::kFanIn>(mode)) {
51     std = gain / std::sqrt(fan.in);
52   } else {
53     std = gain / std::sqrt(fan.out);
54   }
55   return std;
56 }
57 } // namespace
58 
calculate_gain(NonlinearityType nonlinearity,double param)59 double calculate_gain(NonlinearityType nonlinearity, double param) {
60   if (std::holds_alternative<enumtype::kTanh>(nonlinearity)) {
61     return 5.0 / 3.0; // NOLINT
62   } else if (std::holds_alternative<enumtype::kReLU>(nonlinearity)) {
63     return std::sqrt(2.0); // NOLINT
64   } else if (std::holds_alternative<enumtype::kLeakyReLU>(nonlinearity)) {
65     return std::sqrt(2.0 / (1 + pow(param, 2))); // NOLINT
66   }
67 
68   return 1.0;
69 }
70 
constant_(Tensor tensor,Scalar value)71 Tensor constant_(Tensor tensor, Scalar value) {
72   NoGradGuard guard;
73   return tensor.fill_(value);
74 }
75 
dirac_(Tensor tensor)76 Tensor dirac_(Tensor tensor) {
77   NoGradGuard guard;
78 
79   TORCH_CHECK(
80       tensor.ndimension() >= 3 && tensor.ndimension() <= 5,
81       "Only tensors with 3, 4, or 5 dimensions are supported");
82 
83   const auto sizes = tensor.sizes();
84   const auto min_dim = std::min(sizes[0], sizes[1]);
85 
86   tensor.zero_();
87   for (const auto d : c10::irange(min_dim)) {
88     switch (tensor.ndimension()) {
89       case 3: // Temporal convolution
90         tensor[d][d][sizes[2] / 2] = 1;
91         break;
92       case 4: // Spatial convolution
93         tensor[d][d][sizes[2] / 2][sizes[3] / 2] = 1;
94         break;
95       case 5: // Volumetric convolution
96         tensor[d][d][sizes[2] / 2][sizes[3] / 2][sizes[4] / 2] = 1;
97         break;
98     }
99   }
100 
101   return tensor;
102 }
103 
eye_(Tensor matrix)104 Tensor eye_(Tensor matrix) {
105   NoGradGuard guard;
106   TORCH_CHECK(
107       matrix.ndimension() == 2, "Only tensors with 2 dimensions are supported");
108   return torch::eye_out(matrix, matrix.size(0), matrix.size(1));
109 }
110 
normal_(Tensor tensor,double mean,double std)111 Tensor normal_(Tensor tensor, double mean, double std) {
112   NoGradGuard guard;
113   return tensor.normal_(mean, std);
114 }
115 
ones_(Tensor tensor)116 Tensor ones_(Tensor tensor) {
117   NoGradGuard guard;
118   return tensor.fill_(1);
119 }
120 
orthogonal_(Tensor tensor,double gain)121 Tensor orthogonal_(Tensor tensor, double gain) {
122   NoGradGuard guard;
123 
124   TORCH_CHECK(
125       tensor.ndimension() >= 2,
126       "Only tensors with 2 or more dimensions are supported");
127 
128   const auto rows = tensor.size(0);
129   const auto columns = tensor.numel() / rows;
130   auto flattened = torch::randn({rows, columns});
131 
132   if (rows < columns) {
133     flattened.t_();
134   }
135 
136   // Compute the qr factorization
137   auto [q, r] = torch::linalg::qr(flattened);
138   // Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
139   auto d = torch::diag(r, 0);
140   auto ph = d.sign();
141   q *= ph;
142 
143   if (rows < columns) {
144     q.t_();
145   }
146 
147   tensor.view_as(q).copy_(q);
148   tensor.mul_(gain);
149 
150   return tensor;
151 }
152 
sparse_(Tensor tensor,double sparsity,double std)153 Tensor sparse_(Tensor tensor, double sparsity, double std) {
154   NoGradGuard guard;
155 
156   TORCH_CHECK(
157       tensor.ndimension() == 2, "Only tensors with 2 dimensions are supported");
158 
159   const auto rows = tensor.size(0);
160   const auto columns = tensor.size(1);
161   const int64_t num_zeros = std::ceil(sparsity * rows);
162   tensor.normal_(0, std);
163   for (const auto column : c10::irange(columns)) {
164     auto row_indices = torch::randperm(rows, tensor.options().dtype(kLong));
165     auto zero_indices =
166         row_indices.slice(/*dim=*/0, /*start=*/0, /*end=*/num_zeros);
167     tensor.index_put_(
168         {zero_indices, torch::tensor(column, tensor.options().dtype(kLong))},
169         torch::zeros(num_zeros, tensor.options()));
170   }
171 
172   return tensor;
173 }
174 
uniform_(Tensor tensor,double low,double high)175 Tensor uniform_(Tensor tensor, double low, double high) {
176   NoGradGuard guard;
177   return tensor.uniform_(low, high);
178 }
179 
kaiming_uniform_(Tensor tensor,double a,FanModeType mode,NonlinearityType nonlinearity)180 Tensor kaiming_uniform_(
181     Tensor tensor,
182     double a,
183     FanModeType mode,
184     NonlinearityType nonlinearity) {
185   NoGradGuard guard;
186   auto std = calculate_kaiming_std(tensor, a, mode, nonlinearity);
187   // Calculate uniform bounds from standard deviation
188   const auto bound = std::sqrt(3.0) * std;
189   return tensor.uniform_(-bound, bound);
190 }
191 
kaiming_normal_(Tensor tensor,double a,FanModeType mode,NonlinearityType nonlinearity)192 Tensor kaiming_normal_(
193     Tensor tensor,
194     double a,
195     FanModeType mode,
196     NonlinearityType nonlinearity) {
197   NoGradGuard guard;
198 
199   auto std = calculate_kaiming_std(tensor, a, mode, nonlinearity);
200   return tensor.normal_(0, std);
201 }
202 
xavier_normal_(Tensor tensor,double gain)203 Tensor xavier_normal_(Tensor tensor, double gain) {
204   NoGradGuard guard;
205 
206   Fan fan(tensor);
207   // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
208   const auto std = gain * std::sqrt(2.0 / (fan.in + fan.out));
209   return tensor.normal_(0, std);
210 }
211 
xavier_uniform_(Tensor tensor,double gain)212 Tensor xavier_uniform_(Tensor tensor, double gain) {
213   NoGradGuard guard;
214   Fan fan(tensor);
215   // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
216   const auto std = gain * std::sqrt(2.0 / (fan.in + fan.out));
217   // Calculate uniform bounds from standard deviation with
218   const auto a = std::sqrt(3.0) * std;
219   return tensor.uniform_(-a, a);
220 }
221 
zeros_(Tensor tensor)222 Tensor zeros_(Tensor tensor) {
223   NoGradGuard guard;
224   return tensor.zero_();
225 }
226 
_calculate_fan_in_and_fan_out(const Tensor & tensor)227 std::tuple<int64_t, int64_t> _calculate_fan_in_and_fan_out(
228     const Tensor& tensor) {
229   const auto dimensions = tensor.dim();
230   TORCH_CHECK(
231       dimensions >= 2,
232       "Fan in and fan out can not be computed "
233       "for tensor with fewer than 2 dimensions")
234 
235   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
236   int64_t fan_in, fan_out;
237   if (dimensions == 2) { // Linear
238     fan_in = tensor.size(1);
239     fan_out = tensor.size(0);
240   } else {
241     const auto num_input_fmaps = tensor.size(1);
242     const auto num_output_fmaps = tensor.size(0);
243     auto receptive_field_size = 1;
244     if (tensor.dim() > 2) {
245       receptive_field_size = tensor[0][0].numel();
246     }
247     fan_in = num_input_fmaps * receptive_field_size;
248     fan_out = num_output_fmaps * receptive_field_size;
249   }
250   return std::tie(fan_in, fan_out);
251 }
252 
253 } // namespace init
254 } // namespace nn
255 } // namespace torch
256