#pragma once

#include <ATen/ExpandUtils.h>
#include <torch/nn/functional/activation.h>
#include <torch/nn/options/loss.h>

namespace torch {
namespace nn {
namespace functional {

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor l1_loss(
    const Tensor& input,
    const Tensor& target,
    L1LossFuncOptions::reduction_t reduction) {
  return torch::l1_loss(input, target, enumtype::reduction_get_enum(reduction));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.l1_loss
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::L1LossFuncOptions` class
/// to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::l1_loss(input, target, F::L1LossFuncOptions(torch::kNone));
/// ```
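///
/// A fuller sketch (tensor shapes and names here are illustrative, not
/// prescribed by the API):
/// ```
/// auto input = torch::randn({3, 4}, torch::requires_grad());
/// auto target = torch::randn({3, 4});
/// auto loss = F::l1_loss(input, target); // default reduction is torch::kMean
/// loss.backward();
/// ```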
inline Tensor l1_loss(
    const Tensor& input,
    const Tensor& target,
    const L1LossFuncOptions& options = {}) {
  return detail::l1_loss(input, target, options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor kl_div(
    const Tensor& input,
    const Tensor& target,
    KLDivFuncOptions::reduction_t reduction,
    bool log_target = false) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  torch::Reduction::Reduction reduction_enum;

  if (std::holds_alternative<enumtype::kMean>(reduction)) {
    TORCH_WARN(
        "reduction: 'mean' divides the total loss by both the batch size and the support size. "
        "'batchmean' divides only by the batch size, and aligns with the KL div math definition. "
        "'mean' will be changed to behave the same as 'batchmean' in the next major release.");
  }

  // special case for batchmean
  if (std::holds_alternative<enumtype::kBatchMean>(reduction)) {
    reduction_enum = torch::Reduction::Sum;
  } else {
    reduction_enum = enumtype::reduction_get_enum(reduction);
  }

  auto reduced = torch::kl_div(input, target, reduction_enum, log_target);

  if (std::holds_alternative<enumtype::kBatchMean>(reduction) &&
      input.dim() != 0) {
    reduced = reduced / input.sizes()[0];
  }

  return reduced;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.kl_div
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::KLDivFuncOptions` class to
/// learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::kl_div(input, target,
/// F::KLDivFuncOptions().reduction(torch::kNone).log_target(false));
/// ```
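///
/// A minimal sketch (assumes `input` holds log-probabilities and `target`
/// holds probabilities; shapes are illustrative):
/// ```
/// auto input = torch::log_softmax(torch::randn({3, 5}), /*dim=*/1);
/// auto target = torch::softmax(torch::randn({3, 5}), /*dim=*/1);
/// auto loss = F::kl_div(input, target,
///     F::KLDivFuncOptions().reduction(torch::kBatchMean));
/// ```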
inline Tensor kl_div(
    const Tensor& input,
    const Tensor& target,
    const KLDivFuncOptions& options = {}) {
  return detail::kl_div(
      input, target, options.reduction(), options.log_target());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor mse_loss(
    const Tensor& input,
    const Tensor& target,
    MSELossFuncOptions::reduction_t reduction) {
  if (target.sizes() != input.sizes()) {
    TORCH_WARN(
        "Using a target size (",
        target.sizes(),
        ") that is different to the input size (",
        input.sizes(),
        "). ",
        "This will likely lead to incorrect results due to broadcasting. ",
        "Please ensure they have the same size.");
  }
  std::vector<torch::Tensor> broadcast_tensors =
      torch::broadcast_tensors({input, target});
  auto expanded_input = broadcast_tensors[0];
  auto expanded_target = broadcast_tensors[1];
  return torch::mse_loss(
      expanded_input, expanded_target, enumtype::reduction_get_enum(reduction));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.mse_loss
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::MSELossFuncOptions` class
/// to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::mse_loss(input, target, F::MSELossFuncOptions(torch::kNone));
/// ```
inline Tensor mse_loss(
    const Tensor& input,
    const Tensor& target,
    const MSELossFuncOptions& options = {}) {
  return detail::mse_loss(input, target, options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor binary_cross_entropy(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    BinaryCrossEntropyFuncOptions::reduction_t reduction) {
  auto reduction_enum = enumtype::reduction_get_enum(reduction);

  if (target.sizes() != input.sizes()) {
    TORCH_CHECK(
        false,
        "Using a target size (",
        target.sizes(),
        ") ",
        "that is different to the input size (",
        input.sizes(),
        ") is deprecated. ",
        "Please ensure they have the same size.");
  }

  auto weight_ = weight;
  if (weight_.defined()) {
    auto new_size = at::infer_size(target.sizes(), weight_.sizes());
    weight_ = weight_.expand(new_size);
  }

  return torch::binary_cross_entropy(input, target, weight_, reduction_enum);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.binary_cross_entropy
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::BinaryCrossEntropyFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::binary_cross_entropy(input, target,
/// F::BinaryCrossEntropyFuncOptions().weight(weight));
/// ```
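///
/// A minimal sketch (assumes `input` holds probabilities in [0, 1]; names and
/// shapes are illustrative):
/// ```
/// auto input = torch::sigmoid(torch::randn({3, 2}, torch::requires_grad()));
/// auto target = torch::rand({3, 2});
/// auto loss = F::binary_cross_entropy(input, target);
/// ```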
inline Tensor binary_cross_entropy(
    const Tensor& input,
    const Tensor& target,
    const BinaryCrossEntropyFuncOptions& options = {}) {
  return detail::binary_cross_entropy(
      input, target, options.weight(), options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor hinge_embedding_loss(
    const Tensor& input,
    const Tensor& target,
    double margin,
    HingeEmbeddingLossFuncOptions::reduction_t reduction) {
  return torch::hinge_embedding_loss(
      input, target, margin, enumtype::reduction_get_enum(reduction));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.hinge_embedding_loss
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::HingeEmbeddingLossFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::hinge_embedding_loss(input, target,
/// F::HingeEmbeddingLossFuncOptions().margin(2));
/// ```
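///
/// A minimal sketch (hinge embedding expects targets of 1 or -1; values here
/// are illustrative):
/// ```
/// auto input = torch::randn({4});
/// auto target = torch::tensor({1., -1., 1., -1.}, torch::kFloat);
/// auto loss = F::hinge_embedding_loss(input, target);
/// ```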
inline Tensor hinge_embedding_loss(
    const Tensor& input,
    const Tensor& target,
    const HingeEmbeddingLossFuncOptions& options = {}) {
  return detail::hinge_embedding_loss(
      input, target, options.margin(), options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor multi_margin_loss(
    const Tensor& input,
    const Tensor& target,
    int64_t p,
    double margin,
    const Tensor& weight,
    MultiMarginLossFuncOptions::reduction_t reduction) {
  TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");
  if (weight.defined()) {
    TORCH_CHECK(weight.dim() == 1, "weight must be one-dimensional");
  }

  return torch::multi_margin_loss(
      input,
      target,
      p,
      margin,
      weight,
      enumtype::reduction_get_enum(reduction));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.multi_margin_loss
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::MultiMarginLossFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::multi_margin_loss(input, target,
/// F::MultiMarginLossFuncOptions().margin(2).weight(weight));
/// ```
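///
/// A minimal sketch (target holds one class index per sample; shapes are
/// illustrative):
/// ```
/// auto input = torch::randn({3, 5}); // 3 samples, 5 classes
/// auto target = torch::tensor({1, 0, 4}, torch::kLong);
/// auto loss = F::multi_margin_loss(input, target);
/// ```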
inline Tensor multi_margin_loss(
    const Tensor& input,
    const Tensor& target,
    const MultiMarginLossFuncOptions& options = {}) {
  return detail::multi_margin_loss(
      input,
      target,
      options.p(),
      options.margin(),
      options.weight(),
      options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor cosine_embedding_loss(
    const Tensor& input1,
    const Tensor& input2,
    const Tensor& target,
    double margin,
    CosineEmbeddingLossFuncOptions::reduction_t reduction) {
  return torch::cosine_embedding_loss(
      input1, input2, target, margin, enumtype::reduction_get_enum(reduction));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.cosine_embedding_loss
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::CosineEmbeddingLossFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::cosine_embedding_loss(input1, input2, target,
/// F::CosineEmbeddingLossFuncOptions().margin(0.5));
/// ```
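///
/// A minimal sketch (target is 1 for similar pairs and -1 for dissimilar
/// ones; shapes are illustrative):
/// ```
/// auto input1 = torch::randn({3, 8});
/// auto input2 = torch::randn({3, 8});
/// auto target = torch::tensor({1., -1., 1.}, torch::kFloat);
/// auto loss = F::cosine_embedding_loss(input1, input2, target);
/// ```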
inline Tensor cosine_embedding_loss(
    const Tensor& input1,
    const Tensor& input2,
    const Tensor& target,
    const CosineEmbeddingLossFuncOptions& options = {}) {
  return detail::cosine_embedding_loss(
      input1, input2, target, options.margin(), options.reduction());
}

// ============================================================================

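// Element-wise smooth L1 helper: with t = |input - target|, this returns
// 0.5 * t^2 / beta where t < beta, and t - 0.5 * beta elsewhere.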
inline Tensor _smooth_l1_loss(
    const Tensor& input,
    const Tensor& target,
    double beta = 1.) {
  auto t = torch::abs(input - target);
  return torch::where(t < beta, 0.5 * torch::pow(t, 2) / beta, t - 0.5 * beta);
}

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor smooth_l1_loss(
    const Tensor& input,
    const Tensor& target,
    SmoothL1LossFuncOptions::reduction_t reduction,
    std::optional<double> beta_opt = std::nullopt) {
  if (target.sizes() != input.sizes()) {
    TORCH_WARN(
        "Using a target size (",
        target.sizes(),
        ") that is different to the input size (",
        input.sizes(),
        "). ",
        "This will likely lead to incorrect results due to broadcasting. ",
        "Please ensure they have the same size.");
  }
  double beta = beta_opt.value_or(1.0);

  std::vector<Tensor> expanded_tensors =
      torch::broadcast_tensors({input, target});
  return torch::smooth_l1_loss(
      expanded_tensors[0],
      expanded_tensors[1],
      enumtype::reduction_get_enum(reduction),
      beta);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.smooth_l1_loss
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::SmoothL1LossFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::smooth_l1_loss(input, target, F::SmoothL1LossFuncOptions(torch::kNone));
/// ```
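///
/// A minimal sketch (beta controls where the loss switches from quadratic to
/// linear; the value is illustrative):
/// ```
/// auto input = torch::randn({3, 4});
/// auto target = torch::randn({3, 4});
/// auto loss = F::smooth_l1_loss(input, target,
///     F::SmoothL1LossFuncOptions().beta(0.5));
/// ```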
inline Tensor smooth_l1_loss(
    const Tensor& input,
    const Tensor& target,
    const SmoothL1LossFuncOptions& options = {}) {
  return detail::smooth_l1_loss(
      input, target, options.reduction(), options.beta());
}

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.smooth_l1_loss
/// about the exact behavior of this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::smooth_l1_loss(input, target, /*options=*/torch::kNone, /*beta=*/0.5);
/// ```
inline Tensor smooth_l1_loss(
    const Tensor& input,
    const Tensor& target,
    const SmoothL1LossFuncOptions& options,
    double beta) {
  TORCH_CHECK(
      options.beta() == std::nullopt,
      "expected beta not to be provided in 'options', but got ",
      options.beta().value());
  return detail::smooth_l1_loss(input, target, options.reduction(), beta);
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor huber_loss(
    const Tensor& input,
    const Tensor& target,
    HuberLossFuncOptions::reduction_t reduction,
    double delta = 1.) {
  if (target.sizes() != input.sizes()) {
    TORCH_WARN(
        "Using a target size (",
        target.sizes(),
        ") that is different to the input size (",
        input.sizes(),
        "). ",
        "This will likely lead to incorrect results due to broadcasting. ",
        "Please ensure they have the same size.");
  }

  std::vector<Tensor> expanded_tensors =
      torch::broadcast_tensors({input, target});
  return torch::huber_loss(
      expanded_tensors[0],
      expanded_tensors[1],
      enumtype::reduction_get_enum(reduction),
      delta);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.huber_loss
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::HuberLossFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::huber_loss(input, target,
/// F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5));
/// ```
inline Tensor huber_loss(
    const Tensor& input,
    const Tensor& target,
    const HuberLossFuncOptions& options = {}) {
  return detail::huber_loss(
      input, target, options.reduction(), options.delta());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor multilabel_margin_loss(
    const Tensor& input,
    const Tensor& target,
    MultilabelMarginLossFuncOptions::reduction_t reduction) {
  return torch::multilabel_margin_loss(
      input, target, enumtype::reduction_get_enum(reduction));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.multilabel_margin_loss
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::MultilabelMarginLossFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::multilabel_margin_loss(input, target,
/// F::MultilabelMarginLossFuncOptions(torch::kNone));
/// ```
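///
/// A minimal sketch (target has the same shape as input and holds label
/// indices padded with -1; values are illustrative):
/// ```
/// auto input = torch::randn({2, 4});
/// auto target = torch::tensor({{3, 0, -1, -1}, {1, -1, -1, -1}}, torch::kLong);
/// auto loss = F::multilabel_margin_loss(input, target);
/// ```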
inline Tensor multilabel_margin_loss(
    const Tensor& input,
    const Tensor& target,
    const MultilabelMarginLossFuncOptions& options = {}) {
  return detail::multilabel_margin_loss(input, target, options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor soft_margin_loss(
    const Tensor& input,
    const Tensor& target,
    SoftMarginLossFuncOptions::reduction_t reduction) {
  return torch::soft_margin_loss(
      input, target, enumtype::reduction_get_enum(reduction));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.soft_margin_loss
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::SoftMarginLossFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::soft_margin_loss(input, target,
/// F::SoftMarginLossFuncOptions(torch::kNone));
/// ```
inline Tensor soft_margin_loss(
    const Tensor& input,
    const Tensor& target,
    const SoftMarginLossFuncOptions& options = {}) {
  return detail::soft_margin_loss(input, target, options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor multilabel_soft_margin_loss(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    MultilabelSoftMarginLossFuncOptions::reduction_t reduction) {
  auto loss =
      -(target * torch::log_sigmoid(input) +
        (1 - target) * torch::log_sigmoid(-input));
  if (weight.defined()) {
    loss = loss * weight;
  }

  auto class_dim = input.dim() - 1;
  auto C = input.size(class_dim);
  loss = loss.sum(class_dim) / C; // only return N loss values

  Tensor ret;

  if (std::holds_alternative<enumtype::kNone>(reduction)) {
    ret = loss;
  } else if (std::holds_alternative<enumtype::kMean>(reduction)) {
    ret = loss.mean();
  } else if (std::holds_alternative<enumtype::kSum>(reduction)) {
    ret = loss.sum();
  } else {
    ret = input;
    TORCH_INTERNAL_ASSERT(
        false, enumtype::get_enum_name(reduction), " is not valid");
  }
  return ret;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.multilabel_soft_margin_loss
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::MultilabelSoftMarginLossFuncOptions` class to learn
/// what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::multilabel_soft_margin_loss(input, target,
/// F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone).weight(weight));
/// ```
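///
/// A minimal sketch (target is multi-hot with entries in {0, 1}; shapes are
/// illustrative):
/// ```
/// auto input = torch::randn({3, 4});
/// auto target = torch::randint(0, 2, {3, 4}, torch::kFloat);
/// auto loss = F::multilabel_soft_margin_loss(input, target);
/// ```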
inline Tensor multilabel_soft_margin_loss(
    const Tensor& input,
    const Tensor& target,
    const MultilabelSoftMarginLossFuncOptions& options = {}) {
  return detail::multilabel_soft_margin_loss(
      input, target, options.weight(), options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor triplet_margin_loss(
    const Tensor& anchor,
    const Tensor& positive,
    const Tensor& negative,
    double margin,
    double p,
    double eps,
    bool swap,
    TripletMarginLossFuncOptions::reduction_t reduction) {
  return torch::triplet_margin_loss(
      anchor,
      positive,
      negative,
      margin,
      p,
      eps,
      swap,
      enumtype::reduction_get_enum(reduction));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.triplet_margin_loss
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::TripletMarginLossFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::triplet_margin_loss(anchor, positive, negative,
/// F::TripletMarginLossFuncOptions().margin(1.0));
/// ```
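///
/// A minimal sketch (anchor/positive/negative are batches of embeddings;
/// shapes are illustrative):
/// ```
/// auto anchor = torch::randn({16, 128});
/// auto positive = torch::randn({16, 128});
/// auto negative = torch::randn({16, 128});
/// auto loss = F::triplet_margin_loss(anchor, positive, negative);
/// ```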
inline Tensor triplet_margin_loss(
    const Tensor& anchor,
    const Tensor& positive,
    const Tensor& negative,
    const TripletMarginLossFuncOptions& options = {}) {
  return detail::triplet_margin_loss(
      anchor,
      positive,
      negative,
      options.margin(),
      options.p(),
      options.eps(),
      options.swap(),
      options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor triplet_margin_with_distance_loss(
    const Tensor& anchor,
    const Tensor& positive,
    const Tensor& negative,
    std::optional<TripletMarginWithDistanceLossFuncOptions::distance_function_t>
        distance_function,
    double margin,
    bool swap,
    TripletMarginWithDistanceLossFuncOptions::reduction_t reduction) {
  Tensor dist_pos, dist_neg;
  if (distance_function.has_value()) {
    auto distance_function_impl = distance_function.value();
    dist_pos = distance_function_impl(anchor, positive);
    dist_neg = distance_function_impl(anchor, negative);
  } else {
    dist_pos = pairwise_distance(anchor, positive);
    dist_neg = pairwise_distance(anchor, negative);
  }

  if (swap) {
    Tensor dist_swap;
    if (distance_function.has_value()) {
      dist_swap = distance_function.value()(positive, negative);
    } else {
      dist_swap = pairwise_distance(positive, negative);
    }
    dist_neg = torch::min(dist_neg, dist_swap);
  }

  auto loss = torch::clamp_min(dist_pos - dist_neg + margin, 0);

  Tensor ret;
  if (std::holds_alternative<enumtype::kNone>(reduction)) {
    ret = loss;
  } else if (std::holds_alternative<enumtype::kMean>(reduction)) {
    ret = loss.mean();
  } else if (std::holds_alternative<enumtype::kSum>(reduction)) {
    ret = loss.sum();
  } else {
    ret = anchor;
    TORCH_INTERNAL_ASSERT(
        false, enumtype::get_enum_name(reduction), " is not valid");
  }
  return ret;
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.triplet_margin_with_distance_loss
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::TripletMarginWithDistanceLossFuncOptions` class to
/// learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::triplet_margin_with_distance_loss(anchor, positive, negative,
/// F::TripletMarginWithDistanceLossFuncOptions().margin(1.0));
/// ```
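///
/// A sketch with a custom distance function (the lambda is illustrative; any
/// callable mapping two tensors to a distance tensor works):
/// ```
/// auto cosine_distance = [](const torch::Tensor& x, const torch::Tensor& y) {
///   return 1.0 - torch::cosine_similarity(x, y);
/// };
/// F::triplet_margin_with_distance_loss(anchor, positive, negative,
///     F::TripletMarginWithDistanceLossFuncOptions()
///         .distance_function(cosine_distance)
///         .margin(1.0));
/// ```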
inline Tensor triplet_margin_with_distance_loss(
    const Tensor& anchor,
    const Tensor& positive,
    const Tensor& negative,
    const TripletMarginWithDistanceLossFuncOptions& options = {}) {
  return detail::triplet_margin_with_distance_loss(
      anchor,
      positive,
      negative,
      options.distance_function(),
      options.margin(),
      options.swap(),
      options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor ctc_loss(
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    int64_t blank,
    CTCLossFuncOptions::reduction_t reduction,
    bool zero_infinity) {
  return torch::ctc_loss(
      log_probs,
      targets,
      input_lengths,
      target_lengths,
      blank,
      enumtype::reduction_get_enum(reduction),
      zero_infinity);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.ctc_loss
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::CTCLossFuncOptions` class
/// to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::ctc_loss(log_probs, targets, input_lengths, target_lengths,
/// F::CTCLossFuncOptions().reduction(torch::kNone));
/// ```
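///
/// A minimal sketch (T=50 time steps, N=16 batch, C=20 classes with the
/// default blank index 0; all values are illustrative):
/// ```
/// auto log_probs = torch::randn({50, 16, 20}).log_softmax(2);
/// auto targets = torch::randint(1, 20, {16, 30}, torch::kLong);
/// auto input_lengths = torch::full({16}, 50, torch::kLong);
/// auto target_lengths = torch::randint(10, 30, {16}, torch::kLong);
/// auto loss = F::ctc_loss(log_probs, targets, input_lengths, target_lengths);
/// ```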
inline Tensor ctc_loss(
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    const CTCLossFuncOptions& options = {}) {
  return detail::ctc_loss(
      log_probs,
      targets,
      input_lengths,
      target_lengths,
      options.blank(),
      options.reduction(),
      options.zero_infinity());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor poisson_nll_loss(
    const Tensor& input,
    const Tensor& target,
    bool log_input,
    bool full,
    double eps,
    PoissonNLLLossFuncOptions::reduction_t reduction) {
  return torch::poisson_nll_loss(
      input,
      target,
      log_input,
      full,
      eps,
      enumtype::reduction_get_enum(reduction));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.poisson_nll_loss
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::PoissonNLLLossFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::poisson_nll_loss(input, target,
/// F::PoissonNLLLossFuncOptions().reduction(torch::kNone));
/// ```
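///
/// A minimal sketch (with the default log_input=true, input is interpreted as
/// log-rates; names and values are illustrative):
/// ```
/// auto input = torch::randn({4});
/// auto target = torch::poisson(torch::rand({4}) * 5);
/// auto loss = F::poisson_nll_loss(input, target);
/// ```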
inline Tensor poisson_nll_loss(
    const Tensor& input,
    const Tensor& target,
    const PoissonNLLLossFuncOptions& options = {}) {
  return detail::poisson_nll_loss(
      input,
      target,
      options.log_input(),
      options.full(),
      options.eps(),
      options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor margin_ranking_loss(
    const Tensor& input1,
    const Tensor& input2,
    const Tensor& target,
    double margin,
    MarginRankingLossFuncOptions::reduction_t reduction) {
  TORCH_CHECK(
      input1.dim() == input2.dim() && input1.dim() == target.dim(),
      "margin_ranking_loss : All input tensors should have the same number of dimensions, but got sizes: "
      "input1: ",
      input1.sizes(),
      ", input2: ",
      input2.sizes(),
      ", target: ",
      target.sizes());
  return torch::margin_ranking_loss(
      input1, input2, target, margin, enumtype::reduction_get_enum(reduction));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.margin_ranking_loss
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::MarginRankingLossFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::margin_ranking_loss(input1, input2, target,
/// F::MarginRankingLossFuncOptions().margin(0.5).reduction(torch::kSum));
/// ```
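///
/// A minimal sketch (target is 1 when input1 should rank higher than input2,
/// and -1 otherwise; values are illustrative):
/// ```
/// auto input1 = torch::randn({5});
/// auto input2 = torch::randn({5});
/// auto target = torch::tensor({1., 1., -1., 1., -1.}, torch::kFloat);
/// auto loss = F::margin_ranking_loss(input1, input2, target);
/// ```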
inline Tensor margin_ranking_loss(
    const Tensor& input1,
    const Tensor& input2,
    const Tensor& target,
    const MarginRankingLossFuncOptions& options = {}) {
  return detail::margin_ranking_loss(
      input1, input2, target, options.margin(), options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor nll_loss(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t ignore_index,
    const NLLLossFuncOptions::reduction_t& reduction) {
  if (input.dim() < 2) {
    TORCH_CHECK(false, "Expected 2 or more dimensions (got ", input.dim(), ")");
  }

  if (input.sizes()[0] != target.sizes()[0]) {
    TORCH_CHECK(
        false,
        "Expected input batch_size (",
        input.sizes()[0],
        ") to match target batch_size (",
        target.sizes()[0],
        ").");
  }

  return torch::nll_loss_nd(
      input,
      target,
      weight,
      enumtype::reduction_get_enum(reduction),
      ignore_index);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.nll_loss
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::NLLLossFuncOptions` class
/// to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::nll_loss(input, target,
/// F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean));
/// ```
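///
/// A minimal sketch (input holds log-probabilities, e.g. from log_softmax,
/// and target holds class indices; shapes are illustrative):
/// ```
/// auto input = torch::log_softmax(torch::randn({3, 5}), /*dim=*/1);
/// auto target = torch::tensor({1, 0, 4}, torch::kLong);
/// auto loss = F::nll_loss(input, target);
/// ```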
inline Tensor nll_loss(
    const Tensor& input,
    const Tensor& target,
    const NLLLossFuncOptions& options = {}) {
  return detail::nll_loss(
      input,
      target,
      options.weight(),
      options.ignore_index(),
      options.reduction());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor cross_entropy(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t ignore_index,
    CrossEntropyFuncOptions::reduction_t reduction,
    double label_smoothing) {
  return torch::cross_entropy_loss(
      input,
      target,
      weight,
      enumtype::reduction_get_enum(reduction),
      ignore_index,
      label_smoothing);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.cross_entropy
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::CrossEntropyFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::cross_entropy(input, target,
/// F::CrossEntropyFuncOptions().ignore_index(-100).reduction(torch::kMean));
/// ```
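///
/// A minimal sketch (input holds raw, unnormalized logits and target holds
/// class indices; shapes are illustrative):
/// ```
/// auto input = torch::randn({3, 5}, torch::requires_grad());
/// auto target = torch::tensor({1, 0, 4}, torch::kLong);
/// auto loss = F::cross_entropy(input, target);
/// loss.backward();
/// ```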
inline Tensor cross_entropy(
    const Tensor& input,
    const Tensor& target,
    const CrossEntropyFuncOptions& options = {}) {
  return detail::cross_entropy(
      input,
      target,
      options.weight(),
      options.ignore_index(),
      options.reduction(),
      options.label_smoothing());
}

// ============================================================================

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor binary_cross_entropy_with_logits(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    BinaryCrossEntropyWithLogitsFuncOptions::reduction_t reduction,
    const Tensor& pos_weight) {
  TORCH_CHECK(
      target.sizes() == input.sizes(),
      "Target size (",
      target.sizes(),
      ") must be the same as input size (",
      input.sizes(),
      ")");

  return torch::binary_cross_entropy_with_logits(
      input,
      target,
      weight,
      pos_weight,
      enumtype::reduction_get_enum(reduction));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.binary_cross_entropy_with_logits
/// about the exact behavior of this functional.
///
/// See the documentation for
/// `torch::nn::functional::BinaryCrossEntropyWithLogitsFuncOptions` class to
/// learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::binary_cross_entropy_with_logits(input, target,
/// F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight).reduction(torch::kSum));
/// ```
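///
/// A minimal sketch (input holds raw logits, so no sigmoid is applied by the
/// caller; shapes are illustrative):
/// ```
/// auto input = torch::randn({3, 2});
/// auto target = torch::rand({3, 2});
/// auto loss = F::binary_cross_entropy_with_logits(input, target);
/// ```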
inline Tensor binary_cross_entropy_with_logits(
    const Tensor& input,
    const Tensor& target,
    const BinaryCrossEntropyWithLogitsFuncOptions& options = {}) {
  return detail::binary_cross_entropy_with_logits(
      input,
      target,
      options.weight(),
      options.reduction(),
      options.pos_weight());
}

} // namespace functional
} // namespace nn
} // namespace torch