#include <torch/nn/init.h>

#include <torch/linalg.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <tuple>

namespace torch {
namespace nn {
namespace init {
namespace {
struct Fan {
  explicit Fan(Tensor& tensor) {
    const auto dimensions = tensor.ndimension();
    TORCH_CHECK(
        dimensions >= 2,
        "Fan in and fan out cannot be computed for tensor with fewer than 2 dimensions");

    if (dimensions == 2) {
      in = tensor.size(1);
      out = tensor.size(0);
    } else {
      in = tensor.size(1) * tensor[0][0].numel();
      out = tensor.size(0) * tensor[0][0].numel();
    }
  }

  int64_t in;
  int64_t out;
};
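
// Illustrative note: for a conv weight of shape {out_channels, in_channels,
// kH, kW}, say {16, 8, 3, 3}, the receptive field tensor[0][0] has 9
// elements, so this computes in = 8 * 9 = 72 and out = 16 * 9 = 144.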

double calculate_kaiming_std(
    Tensor tensor,
    double a,
    FanModeType mode,
    NonlinearityType nonlinearity) {
  NoGradGuard guard;
  Fan fan(tensor);
  const auto gain = calculate_gain(nonlinearity, a);
  double std = 0.0;

  if (std::holds_alternative<enumtype::kFanIn>(mode)) {
    std = gain / std::sqrt(fan.in);
  } else {
    std = gain / std::sqrt(fan.out);
  }
  return std;
}
} // namespace

double calculate_gain(NonlinearityType nonlinearity, double param) {
  if (std::holds_alternative<enumtype::kTanh>(nonlinearity)) {
    return 5.0 / 3.0; // NOLINT
  } else if (std::holds_alternative<enumtype::kReLU>(nonlinearity)) {
    return std::sqrt(2.0); // NOLINT
  } else if (std::holds_alternative<enumtype::kLeakyReLU>(nonlinearity)) {
    return std::sqrt(2.0 / (1 + std::pow(param, 2))); // NOLINT
  }

  return 1.0;
}
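
// Usage sketch (values follow directly from the branches above; the enum
// tags come from torch/enum.h):
//
//   torch::nn::init::calculate_gain(torch::kTanh);            // 5.0 / 3.0
//   torch::nn::init::calculate_gain(torch::kReLU);            // sqrt(2.0)
//   torch::nn::init::calculate_gain(torch::kLeakyReLU, 0.2);  // sqrt(2.0 / 1.04)
//   torch::nn::init::calculate_gain(torch::kSigmoid);         // 1.0 (fallthrough)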

Tensor constant_(Tensor tensor, Scalar value) {
  NoGradGuard guard;
  return tensor.fill_(value);
}

Tensor dirac_(Tensor tensor) {
  NoGradGuard guard;

  TORCH_CHECK(
      tensor.ndimension() >= 3 && tensor.ndimension() <= 5,
      "Only tensors with 3, 4, or 5 dimensions are supported");

  const auto sizes = tensor.sizes();
  const auto min_dim = std::min(sizes[0], sizes[1]);

  tensor.zero_();
  for (const auto d : c10::irange(min_dim)) {
    switch (tensor.ndimension()) {
      case 3: // Temporal convolution
        tensor[d][d][sizes[2] / 2] = 1;
        break;
      case 4: // Spatial convolution
        tensor[d][d][sizes[2] / 2][sizes[3] / 2] = 1;
        break;
      case 5: // Volumetric convolution
        tensor[d][d][sizes[2] / 2][sizes[3] / 2][sizes[4] / 2] = 1;
        break;
    }
  }

  return tensor;
}
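
// Usage sketch: for a Conv2d-shaped weight, dirac_ places a 1 at the spatial
// center of each preserved channel, so the layer initially acts as an
// identity on those channels. Shapes here are illustrative.
//
//   auto w = torch::empty({8, 8, 3, 3});
//   torch::nn::init::dirac_(w); // w[d][d][1][1] == 1 for d in [0, 8)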

Tensor eye_(Tensor matrix) {
  NoGradGuard guard;
  TORCH_CHECK(
      matrix.ndimension() == 2, "Only tensors with 2 dimensions are supported");
  return torch::eye_out(matrix, matrix.size(0), matrix.size(1));
}
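
// Usage sketch: since eye_out takes both dimensions, non-square matrices
// work too:
//
//   auto m = torch::empty({3, 5});
//   torch::nn::init::eye_(m); // ones on the main diagonal, zeros elsewhere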

Tensor normal_(Tensor tensor, double mean, double std) {
  NoGradGuard guard;
  return tensor.normal_(mean, std);
}

Tensor ones_(Tensor tensor) {
  NoGradGuard guard;
  return tensor.fill_(1);
}

Tensor orthogonal_(Tensor tensor, double gain) {
  NoGradGuard guard;

  TORCH_CHECK(
      tensor.ndimension() >= 2,
      "Only tensors with 2 or more dimensions are supported");

  const auto rows = tensor.size(0);
  const auto columns = tensor.numel() / rows;
  auto flattened = torch::randn({rows, columns});

  if (rows < columns) {
    flattened.t_();
  }

  // Compute the QR factorization
  auto [q, r] = torch::linalg::qr(flattened);
  // Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
  auto d = torch::diag(r, 0);
  auto ph = d.sign();
  q *= ph;

  if (rows < columns) {
    q.t_();
  }

  tensor.view_as(q).copy_(q);
  tensor.mul_(gain);

  return tensor;
}
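
// Usage sketch: the smaller dimension comes out orthonormal (up to `gain`),
// which can be checked numerically:
//
//   auto w = torch::empty({4, 6});
//   torch::nn::init::orthogonal_(w);
//   auto wwt = torch::mm(w, w.t()); // approximately the 4x4 identity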

Tensor sparse_(Tensor tensor, double sparsity, double std) {
  NoGradGuard guard;

  TORCH_CHECK(
      tensor.ndimension() == 2, "Only tensors with 2 dimensions are supported");

  const auto rows = tensor.size(0);
  const auto columns = tensor.size(1);
  const int64_t num_zeros = std::ceil(sparsity * rows);
  tensor.normal_(0, std);
  for (const auto column : c10::irange(columns)) {
    auto row_indices = torch::randperm(rows, tensor.options().dtype(kLong));
    auto zero_indices =
        row_indices.slice(/*dim=*/0, /*start=*/0, /*end=*/num_zeros);
    tensor.index_put_(
        {zero_indices, torch::tensor(column, tensor.options().dtype(kLong))},
        torch::zeros(num_zeros, tensor.options()));
  }

  return tensor;
}
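
// Usage sketch: `sparsity` is the fraction of each column forced to zero
// (exactly std::ceil(sparsity * rows) entries per column):
//
//   auto w = torch::empty({10, 4});
//   torch::nn::init::sparse_(w, /*sparsity=*/0.5, /*std=*/0.01);
//   // each of the 4 columns now has 5 zeroed entries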

Tensor uniform_(Tensor tensor, double low, double high) {
  NoGradGuard guard;
  return tensor.uniform_(low, high);
}

Tensor kaiming_uniform_(
    Tensor tensor,
    double a,
    FanModeType mode,
    NonlinearityType nonlinearity) {
  NoGradGuard guard;
  auto std = calculate_kaiming_std(tensor, a, mode, nonlinearity);
  // Calculate uniform bounds from standard deviation
  const auto bound = std::sqrt(3.0) * std;
  return tensor.uniform_(-bound, bound);
}
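
// Usage sketch: this is the initializer torch::nn::Linear applies to its
// weight (with a = sqrt(5) and the default kFanIn / kLeakyReLU arguments):
//
//   auto w = torch::empty({256, 128});
//   torch::nn::init::kaiming_uniform_(w, /*a=*/std::sqrt(5.0));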

Tensor kaiming_normal_(
    Tensor tensor,
    double a,
    FanModeType mode,
    NonlinearityType nonlinearity) {
  NoGradGuard guard;

  auto std = calculate_kaiming_std(tensor, a, mode, nonlinearity);
  return tensor.normal_(0, std);
}
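
// Usage sketch: fan-out mode with an explicit ReLU gain, a common choice for
// convolutional weights:
//
//   auto w = torch::empty({64, 3, 7, 7});
//   torch::nn::init::kaiming_normal_(w, /*a=*/0, torch::kFanOut, torch::kReLU);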

Tensor xavier_normal_(Tensor tensor, double gain) {
  NoGradGuard guard;

  Fan fan(tensor);
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const auto std = gain * std::sqrt(2.0 / (fan.in + fan.out));
  return tensor.normal_(0, std);
}

Tensor xavier_uniform_(Tensor tensor, double gain) {
  NoGradGuard guard;
  Fan fan(tensor);
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const auto std = gain * std::sqrt(2.0 / (fan.in + fan.out));
  // Calculate uniform bounds from standard deviation
  const auto a = std::sqrt(3.0) * std;
  return tensor.uniform_(-a, a);
}
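
// Usage sketch: both Xavier variants share std = gain * sqrt(2 / (fan_in +
// fan_out)); the uniform variant stretches that into the bound a = sqrt(3) * std:
//
//   auto w = torch::empty({300, 100});
//   torch::nn::init::xavier_uniform_(
//       w, torch::nn::init::calculate_gain(torch::kTanh));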

Tensor zeros_(Tensor tensor) {
  NoGradGuard guard;
  return tensor.zero_();
}

std::tuple<int64_t, int64_t> _calculate_fan_in_and_fan_out(
    const Tensor& tensor) {
  const auto dimensions = tensor.dim();
  TORCH_CHECK(
      dimensions >= 2,
      "Fan in and fan out cannot be computed "
      "for tensor with fewer than 2 dimensions");

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t fan_in, fan_out;
  if (dimensions == 2) { // Linear
    fan_in = tensor.size(1);
    fan_out = tensor.size(0);
  } else {
    const auto num_input_fmaps = tensor.size(1);
    const auto num_output_fmaps = tensor.size(0);
    int64_t receptive_field_size = 1;
    if (tensor.dim() > 2) {
      receptive_field_size = tensor[0][0].numel();
    }
    fan_in = num_input_fmaps * receptive_field_size;
    fan_out = num_output_fmaps * receptive_field_size;
  }
  return std::make_tuple(fan_in, fan_out);
}
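
// Usage sketch: for a {16, 8, 3, 3} conv weight, receptive_field_size is 9,
// so this returns {fan_in = 72, fan_out = 144}:
//
//   auto [fan_in, fan_out] = torch::nn::init::_calculate_fan_in_and_fan_out(
//       torch::empty({16, 8, 3, 3}));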

} // namespace init
} // namespace nn
} // namespace torch