#include <torch/nn/functional/activation.h>
#include <torch/nn/init.h>
#include <torch/nn/modules/activation.h>

namespace F = torch::nn::functional;

namespace torch {
namespace nn {

ELUImpl::ELUImpl(const ELUOptions& options_) : options(options_) {}

Tensor ELUImpl::forward(Tensor input) {
  return F::detail::elu(input, options.alpha(), options.inplace());
}

void ELUImpl::reset() {}

void ELUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::ELU(alpha=" << options.alpha();
  if (options.inplace()) {
    stream << std::boolalpha << ", inplace=" << options.inplace();
  }
  stream << ")";
}

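// Usage sketch (illustrative, not part of the original source): an ELU module
// is configured through ELUOptions and applied like any other torch::nn
// module; the option values below are arbitrary.
//
//   ELU elu(ELUOptions().alpha(0.5).inplace(false));
//   torch::Tensor y = elu(torch::randn({2, 3}));
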
// ============================================================================

SELUImpl::SELUImpl(const SELUOptions& options_) : options(options_) {}

Tensor SELUImpl::forward(Tensor input) {
  return F::detail::selu(input, options.inplace());
}

void SELUImpl::reset() {}

void SELUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::SELU(";
  if (options.inplace()) {
    stream << std::boolalpha << "inplace=" << options.inplace();
  }
  stream << ")";
}

// ============================================================================

HardshrinkImpl::HardshrinkImpl(const HardshrinkOptions& options_)
    : options(options_) {}

Tensor HardshrinkImpl::forward(const Tensor& input) {
  return F::detail::hardshrink(input, options.lambda());
}

void HardshrinkImpl::reset() {}

void HardshrinkImpl::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha << "torch::nn::Hardshrink(" << options.lambda()
         << ")";
}

// ============================================================================

HardtanhImpl::HardtanhImpl(const HardtanhOptions& options_)
    : options(options_) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

Tensor HardtanhImpl::forward(Tensor input) {
  return F::detail::hardtanh(
      input, options.min_val(), options.max_val(), options.inplace());
}

void HardtanhImpl::reset() {
  TORCH_CHECK(
      options.max_val() > options.min_val(),
      "max_val must be greater than min_val");
}

void HardtanhImpl::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha
         << "torch::nn::Hardtanh(min_val=" << options.min_val()
         << ", max_val=" << options.max_val();
  if (options.inplace()) {
    stream << std::boolalpha << ", inplace=" << options.inplace();
  }
  stream << ")";
}

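// Usage sketch (illustrative only): Hardtanh validates its options in reset(),
// which the constructor calls, so an invalid range throws at construction
// time. Values below are arbitrary.
//
//   Hardtanh ht(HardtanhOptions().min_val(-2.0).max_val(2.0));
//   auto y = ht(torch::randn({4}));  // clamps values into [-2, 2]
//   // HardtanhOptions().min_val(1.0).max_val(0.0) would trip the TORCH_CHECK.
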
// ============================================================================

LeakyReLUImpl::LeakyReLUImpl(const LeakyReLUOptions& options_)
    : options(options_) {}

Tensor LeakyReLUImpl::forward(Tensor input) {
  return F::detail::leaky_relu(
      input, options.negative_slope(), options.inplace());
}

void LeakyReLUImpl::reset() {}

void LeakyReLUImpl::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha
         << "torch::nn::LeakyReLU(negative_slope=" << options.negative_slope();
  if (options.inplace()) {
    stream << std::boolalpha << ", inplace=" << options.inplace();
  }
  stream << ")";
}

// ============================================================================

Tensor LogSigmoidImpl::forward(const Tensor& input) {
  return F::logsigmoid(input);
}

void LogSigmoidImpl::reset() {}

void LogSigmoidImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::LogSigmoid()";
}

// ============================================================================

SoftmaxImpl::SoftmaxImpl(const SoftmaxOptions& options_) : options(options_) {}

void SoftmaxImpl::reset() {}

void SoftmaxImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Softmax(dim=" << options.dim() << ")";
}

Tensor SoftmaxImpl::forward(const Tensor& input) {
  return F::detail::softmax(input, options.dim(), std::nullopt);
}

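// Usage sketch (illustrative only): Softmax normalizes along the dimension
// given in SoftmaxOptions; the std::nullopt dtype argument above means the
// output keeps the input's dtype.
//
//   Softmax softmax(SoftmaxOptions(/*dim=*/1));
//   auto probs = softmax(torch::randn({2, 5}));  // each row sums to 1
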
// ============================================================================

SoftminImpl::SoftminImpl(const SoftminOptions& options_) : options(options_) {}

void SoftminImpl::reset() {}

void SoftminImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Softmin(dim=" << options.dim() << ")";
}

Tensor SoftminImpl::forward(const Tensor& input) {
  return F::detail::softmin(input, options.dim(), std::nullopt);
}

// ============================================================================

LogSoftmaxImpl::LogSoftmaxImpl(const LogSoftmaxOptions& options_)
    : options(options_) {}

void LogSoftmaxImpl::reset() {}

void LogSoftmaxImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::LogSoftmax(dim=" << options.dim() << ")";
}

Tensor LogSoftmaxImpl::forward(const Tensor& input) {
  return F::detail::log_softmax(input, options.dim(), std::nullopt);
}

// ============================================================================

void Softmax2dImpl::reset() {}

void Softmax2dImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Softmax2d()";
}

Tensor Softmax2dImpl::forward(const Tensor& input) {
  TORCH_CHECK(
      input.dim() == 4 || input.dim() == 3,
      "Softmax2d requires a 3D or 4D tensor as input");
  return F::detail::softmax(input, /*dim=*/-3, std::nullopt);
}

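// Usage sketch (illustrative only): Softmax2d applies softmax over the channel
// dimension (dim = -3), so for an NCHW input every spatial location's channel
// vector sums to 1.
//
//   Softmax2d softmax2d;
//   auto out = softmax2d(torch::randn({8, 3, 16, 16}));
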
// ============================================================================

PReLUImpl::PReLUImpl(const PReLUOptions& options_) : options(options_) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

Tensor PReLUImpl::forward(const Tensor& input) {
  return F::prelu(input, weight);
}

void PReLUImpl::reset() {
  weight = register_parameter(
      "weight", torch::full(options.num_parameters(), options.init()));
}

void PReLUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::PReLU(num_parameters=" << options.num_parameters()
         << ")";
}

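// Usage sketch (illustrative only): unlike the stateless activations above,
// PReLU owns a learnable "weight" parameter, created in reset() with
// num_parameters entries initialized to init. Values below are arbitrary.
//
//   PReLU prelu(PReLUOptions().num_parameters(64).init(0.25));
//   auto y = prelu(torch::randn({16, 64}));
//   // prelu->weight is a registered parameter, so optimizers will update it.
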
// ============================================================================

ReLUImpl::ReLUImpl(const ReLUOptions& options_) : options(options_) {}

Tensor ReLUImpl::forward(Tensor input) {
  return F::detail::relu(input, options.inplace());
}

void ReLUImpl::reset() {}

void ReLUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::ReLU(";
  if (options.inplace()) {
    stream << std::boolalpha << "inplace=" << options.inplace();
  }
  stream << ")";
}

// ============================================================================

ReLU6Impl::ReLU6Impl(const ReLU6Options& options_) : options(options_) {}

Tensor ReLU6Impl::forward(Tensor input) {
  return F::detail::relu6(input, options.inplace());
}

void ReLU6Impl::reset() {}

void ReLU6Impl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::ReLU6(";
  if (options.inplace()) {
    stream << std::boolalpha << "inplace=" << options.inplace();
  }
  stream << ")";
}

// ============================================================================

RReLUImpl::RReLUImpl(const RReLUOptions& options_) : options(options_) {}

Tensor RReLUImpl::forward(Tensor input) {
  return F::detail::rrelu(
      input,
      options.lower(),
      options.upper(),
      is_training(),
      options.inplace());
}

void RReLUImpl::reset() {}

void RReLUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::RReLU(lower=" << options.lower()
         << ", upper=" << options.upper();
  if (options.inplace()) {
    stream << std::boolalpha << ", inplace=" << options.inplace();
  }
  stream << ")";
}

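// Usage sketch (illustrative only): RReLU's forward passes is_training()
// through to the functional call, so the negative slope is sampled from
// [lower, upper] in training mode and fixed in eval mode.
//
//   RReLU rrelu(RReLUOptions().lower(1.0 / 8).upper(1.0 / 3));
//   rrelu->train();                            // randomized slopes
//   auto y_train = rrelu(torch::randn({4}));
//   rrelu->eval();                             // deterministic slope
//   auto y_eval = rrelu(torch::randn({4}));
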
// ============================================================================

CELUImpl::CELUImpl(const CELUOptions& options_) : options(options_) {}

Tensor CELUImpl::forward(Tensor input) {
  return F::detail::celu(input, options.alpha(), options.inplace());
}

void CELUImpl::reset() {}

void CELUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::CELU(alpha=" << options.alpha();
  if (options.inplace()) {
    stream << std::boolalpha << ", inplace=" << options.inplace();
  }
  stream << ")";
}

// ============================================================================

GLUImpl::GLUImpl(const GLUOptions& options_) : options(options_) {}

Tensor GLUImpl::forward(const Tensor& input) {
  return F::detail::glu(input, options.dim());
}

void GLUImpl::reset() {}

void GLUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::GLU(dim=" << options.dim() << ")";
}

// ============================================================================

GELUImpl::GELUImpl(GELUOptions options_) : options(std::move(options_)) {}

Tensor GELUImpl::forward(const Tensor& input) {
  return F::detail::gelu(input, options.approximate());
}

void GELUImpl::reset() {}

void GELUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::GELU()";
}

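// Usage sketch (illustrative only): GELUOptions carries an "approximate"
// string ("none" or "tanh") selecting between the exact erf-based GELU and
// the tanh approximation; note that pretty_print does not display it.
//
//   GELU gelu(GELUOptions().approximate("tanh"));
//   auto y = gelu(torch::randn({2, 3}));
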
// ============================================================================

Tensor SiLUImpl::forward(const Tensor& input) {
  return F::silu(input);
}

void SiLUImpl::reset() {}

void SiLUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::SiLU()";
}

// ============================================================================

Tensor MishImpl::forward(const Tensor& input) {
  return F::mish(input);
}

void MishImpl::reset() {}

void MishImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Mish()";
}

// ============================================================================

Tensor SigmoidImpl::forward(const Tensor& input) {
  return torch::sigmoid(input);
}

void SigmoidImpl::reset() {}

void SigmoidImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Sigmoid()";
}

// ============================================================================

SoftplusImpl::SoftplusImpl(const SoftplusOptions& options_)
    : options(options_) {}

Tensor SoftplusImpl::forward(const Tensor& input) {
  return F::detail::softplus(input, options.beta(), options.threshold());
}

void SoftplusImpl::reset() {}

void SoftplusImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Softplus(beta=" << options.beta()
         << ", threshold=" << options.threshold() << ")";
}

// ============================================================================

SoftshrinkImpl::SoftshrinkImpl(const SoftshrinkOptions& options_)
    : options(options_) {}

Tensor SoftshrinkImpl::forward(const Tensor& input) {
  return F::detail::softshrink(input, options.lambda());
}

void SoftshrinkImpl::reset() {}

void SoftshrinkImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Softshrink(" << options.lambda() << ")";
}

// ============================================================================

Tensor SoftsignImpl::forward(const Tensor& input) {
  return F::softsign(input);
}

void SoftsignImpl::reset() {}

void SoftsignImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Softsign()";
}

// ============================================================================

Tensor TanhImpl::forward(const Tensor& input) {
  return torch::tanh(input);
}

void TanhImpl::reset() {}

void TanhImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Tanh()";
}

// ============================================================================

Tensor TanhshrinkImpl::forward(const Tensor& input) {
  return F::tanhshrink(input);
}

void TanhshrinkImpl::reset() {}

void TanhshrinkImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Tanhshrink()";
}

// ============================================================================

ThresholdImpl::ThresholdImpl(const ThresholdOptions& options_)
    : options(options_) {}

Tensor ThresholdImpl::forward(Tensor input) {
  return F::detail::threshold(
      input, options.threshold(), options.value(), options.inplace());
}

void ThresholdImpl::reset() {}

void ThresholdImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Threshold(threshold=" << options.threshold()
         << ", value=" << options.value();
  if (options.inplace()) {
    stream << std::boolalpha << ", inplace=" << options.inplace();
  }
  stream << ")";
}

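// Usage sketch (illustrative only): elements at or below the threshold are
// replaced with the configured value; everything above passes through
// unchanged. Values below are arbitrary.
//
//   Threshold th(ThresholdOptions(/*threshold=*/0.1, /*value=*/0.0));
//   auto y = th(torch::randn({5}));
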
// ============================================================================

MultiheadAttentionImpl::MultiheadAttentionImpl(
    const MultiheadAttentionOptions& options_)
    : Cloneable("torch::nn::MultiheadAttention"), options(options_) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

std::tuple<Tensor, Tensor> MultiheadAttentionImpl::forward(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const Tensor& key_padding_mask,
    bool need_weights,
    const Tensor& attn_mask,
    bool average_attn_weights) {
  if (!_qkv_same_embed_dim) {
    return F::multi_head_attention_forward(
        query,
        key,
        value,
        F::MultiheadAttentionForwardFuncOptions(
            /*embed_dim_to_check=*/options.embed_dim(),
            /*num_heads=*/options.num_heads(),
            /*in_proj_weight=*/in_proj_weight,
            /*in_proj_bias=*/in_proj_bias,
            /*bias_k=*/bias_k,
            /*bias_v=*/bias_v,
            /*add_zero_attn=*/options.add_zero_attn(),
            /*dropout_p=*/options.dropout(),
            /*out_proj_weight=*/out_proj->weight,
            /*out_proj_bias=*/out_proj->bias)
            .training(is_training())
            .key_padding_mask(key_padding_mask)
            .need_weights(need_weights)
            .attn_mask(attn_mask)
            .use_separate_proj_weight(true)
            .q_proj_weight(q_proj_weight)
            .k_proj_weight(k_proj_weight)
            .v_proj_weight(v_proj_weight)
            .average_attn_weights(average_attn_weights));
  } else {
    return F::multi_head_attention_forward(
        query,
        key,
        value,
        F::MultiheadAttentionForwardFuncOptions(
            /*embed_dim_to_check=*/options.embed_dim(),
            /*num_heads=*/options.num_heads(),
            /*in_proj_weight=*/in_proj_weight,
            /*in_proj_bias=*/in_proj_bias,
            /*bias_k=*/bias_k,
            /*bias_v=*/bias_v,
            /*add_zero_attn=*/options.add_zero_attn(),
            /*dropout_p=*/options.dropout(),
            /*out_proj_weight=*/out_proj->weight,
            /*out_proj_bias=*/out_proj->bias)
            .training(is_training())
            .key_padding_mask(key_padding_mask)
            .need_weights(need_weights)
            .attn_mask(attn_mask)
            .average_attn_weights(average_attn_weights));
  }
}

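// Usage sketch (illustrative only): in the C++ API query/key/value are laid
// out as (seq_len, batch, embed_dim); forward() returns the attention output
// and, when need_weights is true (the default), the attention weights.
//
//   MultiheadAttention mha(
//       MultiheadAttentionOptions(/*embed_dim=*/512, /*num_heads=*/8));
//   auto q = torch::randn({10, 4, 512});
//   auto [attn_out, attn_weights] = mha(q, q, q);  // self-attention
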
void MultiheadAttentionImpl::reset() {
  _qkv_same_embed_dim = options.kdim() == options.embed_dim() &&
      options.vdim() == options.embed_dim();
  head_dim = options.embed_dim() / options.num_heads();
  TORCH_CHECK(
      head_dim * options.num_heads() == options.embed_dim(),
      "embed_dim must be divisible by num_heads");
  if (!_qkv_same_embed_dim) {
    q_proj_weight = register_parameter(
        "q_proj_weight",
        torch::empty({options.embed_dim(), options.embed_dim()}));
    k_proj_weight = register_parameter(
        "k_proj_weight", torch::empty({options.embed_dim(), options.kdim()}));
    v_proj_weight = register_parameter(
        "v_proj_weight", torch::empty({options.embed_dim(), options.vdim()}));
    register_parameter("in_proj_weight", {}, /*requires_grad=*/false);
  } else {
    in_proj_weight = register_parameter(
        "in_proj_weight",
        torch::empty({3 * options.embed_dim(), options.embed_dim()}));
    register_parameter("q_proj_weight", {}, /*requires_grad=*/false);
    register_parameter("k_proj_weight", {}, /*requires_grad=*/false);
    register_parameter("v_proj_weight", {}, /*requires_grad=*/false);
  }
  if (options.bias()) {
    in_proj_bias = register_parameter(
        "in_proj_bias", torch::empty(3 * options.embed_dim()));
  } else {
    register_parameter("in_proj_bias", {}, /*requires_grad=*/false);
  }
  out_proj = register_module(
      "out_proj",
      Linear(LinearOptions(options.embed_dim(), options.embed_dim())
                 .bias(options.bias())));
  if (options.add_bias_kv()) {
    bias_k =
        register_parameter("bias_k", torch::empty({1, 1, options.embed_dim()}));
    bias_v =
        register_parameter("bias_v", torch::empty({1, 1, options.embed_dim()}));
  } else {
    bias_k.reset();
    bias_v.reset();
  }
  _reset_parameters();
}

void MultiheadAttentionImpl::_reset_parameters() {
  using namespace torch::nn::init;
  if (_qkv_same_embed_dim) {
    xavier_uniform_(in_proj_weight);
  } else {
    xavier_uniform_(q_proj_weight);
    xavier_uniform_(k_proj_weight);
    xavier_uniform_(v_proj_weight);
  }
  if (in_proj_bias.defined()) {
    constant_(in_proj_bias, 0.);
    constant_(out_proj->bias, 0.);
  }
  if (bias_k.defined()) {
    xavier_normal_(bias_k);
  }
  if (bias_v.defined()) {
    xavier_normal_(bias_v);
  }
}

} // namespace nn
} // namespace torch