#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/LossMulti.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/multilabel_margin_loss_backward_native.h>
#include <ATen/ops/multilabel_margin_loss_forward.h>
#include <ATen/ops/multilabel_margin_loss_forward_native.h>
#include <ATen/ops/multilabel_margin_loss_native.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace at::native {

namespace {

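// Per-sample loss for one row of `dim` classes. The target list is terminated
// by the first negative index; every valid target is flagged in is_target_data,
// and for each (target, non-target) pair the hinge term
// max(0, 1 - input[target] + input[other]) is accumulated into the returned sum.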
template <typename scalar_t>
inline scalar_t multilabel_margin_loss_forward_inner_sum_cpu(
    const scalar_t* input_data,
    const int64_t* target_data,
    scalar_t* is_target_data,
    int64_t dim) {
  using accscalar_t = at::acc_type<scalar_t, false>;
  accscalar_t sum = 0;
  for (const auto ddt : c10::irange(dim)) {
    int64_t target_idx = target_data[ddt];
    if (target_idx < 0) {
      break;
    }
    is_target_data[target_idx] = 1;
  }
  for (const auto dt : c10::irange(dim)) {
    int64_t target_idx = target_data[dt];
    if (target_idx < 0) {
      break;
    }

    scalar_t input_target = input_data[target_idx];
    for (const auto d : c10::irange(dim)) {
      if (!is_target_data[d]) {
        scalar_t z = 1 - input_target + input_data[d];
        if (z > 0) {
          sum += z;
        }
      }
    }
  }

  return sum;
}

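// Forward kernel over the batch of nframe contiguous rows. With a reduction (or
// a 0-dim output) the per-row sums are accumulated into a single scalar and
// normalized by dim (and by nframe for Reduction::Mean); otherwise each row's
// loss, normalized by dim, is written to the corresponding entry of the 1-D output.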
template <typename scalar_t>
static void multilabel_margin_loss_forward_out_frame(
    const Tensor& input_contiguous,
    const Tensor& target_contiguous,
    Tensor& output,
    Tensor& is_target,
    int64_t reduction,
    int64_t nframe,
    int64_t dim) {
  using accscalar_t = at::acc_type<scalar_t, false>;
  const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
  const int64_t* target_data = target_contiguous.const_data_ptr<int64_t>();
  scalar_t* is_target_data = is_target.data_ptr<scalar_t>();

  if (reduction != Reduction::None || output.dim() == 0) {
    scalar_t* output_data = output.data_ptr<scalar_t>();

    accscalar_t sum = 0;

    for (C10_UNUSED const auto t : c10::irange(nframe)) {
      sum += multilabel_margin_loss_forward_inner_sum_cpu(
          input_data, target_data, is_target_data, dim);

      input_data += dim;
      target_data += dim;
      is_target_data += dim;
    }

    sum /= dim;
    if (reduction == Reduction::Mean) {
      sum /= nframe;
    }

    *output_data = sum; // write scalar output value
  } else {
    auto output_acc = output.accessor<scalar_t, 1>();

    for (const auto t : c10::irange(nframe)) {
      scalar_t sum = multilabel_margin_loss_forward_inner_sum_cpu(
          input_data, target_data, is_target_data, dim);

      sum /= dim;
      output_acc[t] = sum;

      input_data += dim;
      target_data += dim;
      is_target_data += dim;
    }
  }
}

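// Forward driver: checks shapes, resizes output (scalar unless reduction is None
// and target is 2-D) and is_target, validates that targets lie in [-1, dim), and
// dispatches the frame kernel above over the floating-point dtypes.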
static void multilabel_margin_loss_forward_out_cpu_template(
    const Tensor& input,
    const Tensor& target,
    Tensor& output,
    Tensor& is_target,
    int64_t reduction) {
#ifndef STRIP_ERROR_MESSAGES
  auto target_arg = TensorArg(target, "target", 2);
#endif
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t nframe, dim;
  const int64_t ndims = input.dim();
  multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target);

  // special case target.dim() <= 1: produce scalar output for scalar inputs
  // even if reduction == Reduction::None
  if (reduction != Reduction::None || target.dim() <= 1) {
    output.resize_({});
  } else {
    output.resize_({nframe});
  }

  is_target.resize_as_(target);
  TORCH_CHECK(is_target.is_contiguous(), "is_target must be contiguous");
  is_target.zero_();

  if (input.numel() == 0) {
    return;
  }

  TORCH_CHECK(
      target.min().item<int64_t>() >= -1, target_arg, " is out of range");
  TORCH_CHECK(
      target.max().item<int64_t>() < dim, target_arg, " is out of range");

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multilabel_margin_loss_forward_out_frame", [&] {
        multilabel_margin_loss_forward_out_frame<scalar_t>(
            input_contiguous, target_contiguous, output, is_target, reduction, nframe, dim);
      });
}

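// Backward kernel: for every (target, non-target) pair with an active hinge term,
// g is subtracted from grad_input at the target index and added at the other
// index, with g = 1/dim (or 1/(nframe * dim) for Reduction::Mean). The result is
// then scaled by grad_output: a single scalar for reduced losses, one value per
// row otherwise.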
template <typename scalar_t>
static void multilabel_margin_loss_backward_out_frame(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input_contiguous,
    const Tensor& target_contiguous,
    int64_t reduction,
    const Tensor& is_target_contiguous,
    int64_t nframe,
    int64_t dim) {
#ifndef STRIP_ERROR_MESSAGES
  auto is_target_arg = TensorArg(is_target_contiguous, "is_target", 5);
#endif

  TORCH_CHECK(
      is_target_contiguous.min().item<scalar_t>() >= 0, is_target_arg, " is out of range");
  TORCH_CHECK(
      is_target_contiguous.max().item<scalar_t>() <= 1, is_target_arg, " is out of range");

  const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
  const int64_t* target_data = target_contiguous.const_data_ptr<int64_t>();
  const scalar_t* is_target_data = is_target_contiguous.const_data_ptr<scalar_t>();
  scalar_t g = static_cast<scalar_t>(
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      reduction == Reduction::Mean ? 1. / (nframe * dim) : 1. / dim);

  scalar_t* grad_input_row_data = grad_input.mutable_data_ptr<scalar_t>();
  for (C10_UNUSED const auto t : c10::irange(nframe)) {
    for (const auto dt : c10::irange(dim)) {
      int64_t target_idx = target_data[dt];
      if (target_idx < 0) {
        break;
      }

      scalar_t input_target = input_data[target_idx];
      for (const auto d : c10::irange(dim)) {
        if (!is_target_data[d]) {
          scalar_t z = 1 - input_target + input_data[d];
          if (z > 0) {
            grad_input_row_data[target_idx] -= g;
            grad_input_row_data[d] += g;
          }
        }
      }
    }
    input_data += dim;
    target_data += dim;
    is_target_data += dim;
    grad_input_row_data += dim;
  }

  scalar_t* grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  if (reduction != Reduction::None || grad_output.dim() == 0) {
    assert(
        reduction != Reduction::None || grad_output.dim() > 0 || nframe == 1);
    const auto d = *grad_output.const_data_ptr<scalar_t>();
    for (int64_t t = 0; t < nframe * dim; t++) {
      grad_input_data[t] *= d;
    }
  } else {
    check_dim_size(grad_output, 1, 0, nframe);
    auto grad_output_acc = grad_output.accessor<const scalar_t, 1>();
    for (const auto t : c10::irange(nframe)) {
      for (const auto d : c10::irange(dim)) {
        grad_input_data[t * dim + d] *= grad_output_acc[t];
      }
    }
  }
}

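// Backward driver: checks shapes, verifies that target and is_target have the
// same size, resizes and zeroes grad_input, validates the target range, and
// dispatches the backward frame kernel over the floating-point dtypes.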
static void multilabel_margin_loss_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t nframe, dim;
  CheckedFrom c = "multilabel_margin_loss_backward_cpu_template";
  auto target_arg = TensorArg(target, "target", 3);
  auto is_target_arg = TensorArg(is_target, "is_target", 5);
  const int64_t ndims = input.dim();

  multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target);
  checkSameSize(c, target_arg, is_target_arg);

  grad_input.resize_as_(input);
  if (grad_input.numel() == 0) {
    return;
  }

  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  grad_input.zero_();

  TORCH_CHECK(
      target.min().item<int64_t>() >= -1, target_arg, " is out of range");
  TORCH_CHECK(
      target.max().item<int64_t>() < dim, target_arg, " is out of range");

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();
  auto is_target_contiguous = is_target.contiguous();

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multilabel_margin_loss_backward_out_frame", [&] {
        multilabel_margin_loss_backward_out_frame<scalar_t>(
            grad_input,
            grad_output,
            input_contiguous,
            target_contiguous,
            reduction,
            is_target_contiguous,
            nframe,
            dim);
      });
}

} // namespace

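// out= variant of the CPU forward; fills `output` and `is_target` in place.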
std::tuple<Tensor&, Tensor&> multilabel_margin_loss_forward_out_cpu(const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    Tensor& output,
    Tensor& is_target) {
  multilabel_margin_loss_forward_out_cpu_template(
      self, target, output, is_target, reduction);
  return std::tuple<Tensor&, Tensor&>(output, is_target);
}

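// Functional CPU forward: allocates `output` and `is_target`, then reuses the
// out= variant above.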
std::tuple<Tensor, Tensor> multilabel_margin_loss_forward_cpu(
    const Tensor& self,
    const Tensor& target,
    int64_t reduction) {
  auto output = at::empty({0}, self.options());
  auto is_target = at::empty({0}, self.options());
  at::native::multilabel_margin_loss_forward_out_cpu(
      self, target, reduction, output, is_target);
  return std::make_tuple(output, is_target);
}

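// out= variant of the CPU backward; fills `grad_input` in place.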
Tensor& multilabel_margin_loss_backward_cpu_out(const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target,
    Tensor& grad_input) {
  multilabel_margin_loss_backward_out_cpu_template(
      grad_input, grad_output, self, target, reduction, is_target);
  return grad_input;
}

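// Functional CPU backward: allocates a zero-initialized `grad_input`, then
// reuses the out= variant above.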
Tensor multilabel_margin_loss_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target) {
  auto grad_input = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  at::native::multilabel_margin_loss_backward_cpu_out(
      grad_output, self, target, reduction, is_target, grad_input);
  return grad_input;
}

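// out= wrapper for the composite op: runs the forward and discards is_target.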
Tensor& multilabel_margin_loss_out(const Tensor& self, const Tensor& target, int64_t reduction, Tensor& output) {
  Tensor is_target = at::empty({0}, self.options());
  return std::get<0>(at::multilabel_margin_loss_forward_out(output, is_target, self, target, reduction));
}

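// Composite op: returns only the loss from multilabel_margin_loss_forward.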
Tensor multilabel_margin_loss(const Tensor& self, const Tensor& target, int64_t reduction) {
  return std::get<0>(at::multilabel_margin_loss_forward(self, target, reduction));
}

} // namespace at::native