xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/LossMulti.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 #include <ATen/AccumulateType.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/TensorUtils.h>
6 
7 namespace at::native {
multilabel_margin_loss_shape_check(int64_t & nframe,int64_t & dim,const int64_t & ndims,const Tensor & input,const Tensor & target)8   inline void multilabel_margin_loss_shape_check(
9     int64_t& nframe,
10     int64_t& dim,
11     const int64_t& ndims,
12     const Tensor& input,
13     const Tensor& target) {
14     TORCH_CHECK(
15         (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
16         "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
17         input.sizes());
18 
19     if (ndims <= 1) {
20       nframe = 1;
21       dim = ndims == 0 ? 1 : input.size(0);
22       TORCH_CHECK(
23           target.dim() <= 1 && target.numel() == dim,
24           "inconsistent target size: ", target.sizes(), " for input of size: ",
25           input.sizes());
26     } else {
27       nframe = input.size(0);
28       dim = input.size(1);
29       TORCH_CHECK(
30           target.dim() == 2 && target.size(0) == nframe &&
31           target.size(1) == dim,
32           "inconsistent target size: ", target.sizes(), " for input of size: ",
33           input.sizes());
34     }
35   }
36 
multi_margin_loss_shape_check(int64_t & nframe,int64_t & dim,const int64_t & ndims,const Tensor & input,const Tensor & target,const std::optional<Tensor> & weight)37   inline void multi_margin_loss_shape_check(
38     int64_t& nframe,
39     int64_t& dim,
40     const int64_t& ndims,
41     const Tensor& input,
42     const Tensor& target,
43     const std::optional<Tensor>& weight) {
44     TORCH_CHECK(
45         (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
46         "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
47         input.sizes());
48 
49     if (ndims <= 1) {
50       nframe = 1;
51       dim = ndims == 0 ? 1 : input.size(0);
52     } else {
53       nframe = input.size(0);
54       dim = input.size(1);
55     }
56 
57     TORCH_CHECK(
58         target.dim() <= 1 && target.numel() == nframe,
59         "inconsistent target size, expected ", nframe, " but got ",
60         target.sizes());
61     if (weight && weight->defined()) {
62       TORCH_CHECK(
63           weight->dim() <= 1 && weight->numel() == dim,
64           "inconsistent weight size, expected ", dim, " but got ",
65           weight->sizes());
66     }
67 }
68 
69 } // namespace at::native
70