// aten/src/ATen/native/LossNLL2d.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <limits>
#include <numeric>
#include <type_traits>
#include <utility>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/nll_loss2d_backward_native.h>
#include <ATen/ops/nll_loss2d_forward.h>
#include <ATen/ops/nll_loss2d_forward_native.h>
#include <ATen/ops/nll_loss2d_native.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace at::native {

namespace {

// Returns a contiguous tensor if the source tensor
// is defined. Otherwise returns the undefined
// source tensor unmodified.
inline Tensor optional_contiguous(const Tensor& source) {
  return source.defined() ? source.contiguous() : source;
}

// Returns the address of the first element of a tensor
// or nullptr if the tensor is undefined.
template <typename scalar_t>
inline scalar_t* optional_data(const Tensor& source) {
  if constexpr (std::is_const<scalar_t>::value) {
    return source.defined() ? source.const_data_ptr<scalar_t>() : nullptr;
  } else {
    return source.defined() ? source.data_ptr<scalar_t>() : nullptr;
  }
}
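
// These two helpers let the optional `weight` tensor flow through the kernels
// below when it is undefined: optional_contiguous() passes it through
// unchanged and optional_data() then yields nullptr, so the inner loops can
// branch on a plain pointer check instead of Tensor::defined().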

inline void check_inputs_nll_loss2d(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight) {
  TORCH_CHECK(
      target.dim() == 3,
      "only batches of spatial targets supported (3D tensors)"
      " but got targets of dimension: ",
      target.dim());
  TORCH_CHECK(
      input.dim() == 4,
      "only batches of spatial inputs supported (4D tensors), "
      "but got input of dimension: ",
      input.dim());
  TORCH_CHECK(
      !weight.defined() || weight.numel() == input.size(1),
      "weight tensor should be defined either for all or no classes");

  const int64_t input0 = input.size(0);
  const int64_t input2 = input.size(2);
  const int64_t input3 = input.size(3);
  const int64_t target0 = target.size(0);
  const int64_t target1 = target.size(1);
  const int64_t target2 = target.size(2);
  TORCH_CHECK(
      input0 == target0 && input2 == target1 && input3 == target2,
      "size mismatch (got input: ",
      input.sizes(),
      " , target: ",
      target.sizes());
}
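
// Shape contract, illustrated: for an input of shape [N, C, H, W] the target
// must be [N, H, W] with integer class indices, and the optional weight, when
// defined, must hold C per-class weights. For example, an input of size
// {4, 10, 8, 8} pairs with a target of size {4, 8, 8} and a weight of size
// {10}.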

inline void check_gradout_shape_nll_loss2d(
    const Tensor& grad_output,
    const Tensor& target) {
  TORCH_CHECK(
      grad_output.dim() == 3,
      "grad_output must have same dimension as target (3) but got dimension: ",
      grad_output.sizes());

  const int64_t grad_output0 = grad_output.size(0);
  const int64_t grad_output1 = grad_output.size(1);
  const int64_t grad_output2 = grad_output.size(2);
  const int64_t target0 = target.size(0);
  const int64_t target1 = target.size(1);
  const int64_t target2 = target.size(2);
  TORCH_CHECK(
      grad_output0 == target0 && grad_output1 == target1 &&
          grad_output2 == target2,
      "size mismatch (got grad_output: ",
      grad_output.sizes(),
      " target: ",
      target.sizes());
}

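// Per-element loss, with `input` expected to hold log-probabilities:
//   loss[b][h][w] = -weight[target[b][h][w]] * input[b][target[b][h][w]][h][w]
// (the weight factor is 1 when no weight tensor is given). Elements whose
// target equals ignore_index contribute zero loss and zero weight. With
// Reduction::Sum the per-element losses are summed; with Reduction::Mean the
// sum is divided by the total weight of the non-ignored elements.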
template <typename scalar_t>
static void nll_loss2d_forward_out_frame(
    Tensor& output,
    Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  const int64_t n_classes = input.size(1);

  scalar_t* total_weight_data = total_weight.data_ptr<scalar_t>();
  *total_weight_data = 0;

  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);

  if (reduction == Reduction::None) {
    const int64_t batch_size = input.size(0);
    const int64_t H = input.size(2);
    const int64_t W = input.size(3);

    at::native::resize_output(output, {batch_size, H, W});
    auto input_acc = input.accessor<const scalar_t, 4>();
    auto output_acc = output.accessor<scalar_t, 3>();
    auto target_acc = target.accessor<const int64_t, 3>();

    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto b : c10::irange(start, end)) {
        for (const auto h : c10::irange(H)) {
          for (const auto w : c10::irange(W)) {
            const int64_t cur_target = (int64_t)target_acc[b][h][w];

            if (cur_target == ignore_index) {
              output_acc[b][h][w] = static_cast<scalar_t>(0);
              continue;
            }

            TORCH_CHECK_INDEX(
                cur_target >= 0 && cur_target < n_classes,
                "Target ",
                cur_target,
                " is out of bounds.");

            // load optional weight value
            const scalar_t cur_weight = weight_data != nullptr
                ? weight_data[cur_target]
                : static_cast<scalar_t>(1);
            output_acc[b][h][w] = -input_acc[b][cur_target][h][w] * cur_weight;
          }
        }
      }
    });

    return;
  }

  // produce scalar outputs for the reduction case
  at::native::resize_output(output, {});

  if (target.numel() == 0) {
    // Here target (and input) have zero elements
    // Mean reduction on empty tensors produces NaN. See the discussion in
    // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if (reduction == Reduction::Mean) {
      output.fill_(std::numeric_limits<double>::quiet_NaN());
    } else {
      output.zero_();
    }
    total_weight.zero_();
    return;
  }

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();

  const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
  const int64_t* target_data = target_contiguous.const_data_ptr<int64_t>();

  const int64_t batch_size = input.size(0);
  const int64_t map_size = input.size(2) * input.size(3);
  const int64_t sample_size = map_size * n_classes;
  const int64_t numiter = batch_size * map_size;

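  // Accumulate loss and weight with a cascade (tree) summation rather than a
  // single running sum: level 0 collects at most ~level_step terms before it
  // is flushed into level 1 (whenever the low bits of the linear index wrap to
  // zero), level 1 into level 2, and so on. Keeping each accumulator small
  // reduces floating-point rounding error, which matters for the half and
  // bfloat16 dtypes dispatched below.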
  constexpr int64_t cascade_sum_num_levels = 8;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  scalar_t weight_partial_sums[cascade_sum_num_levels] = {0};
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
  const int64_t level_power =
      std::max(int64_t(4), utils::CeilLog2(numiter) / cascade_sum_num_levels);
  const int64_t level_step = (1 << level_power);
  const int64_t level_mask = level_step - 1;

  int64_t num_ignored = 0;
  for (const auto b : c10::irange(batch_size)) {
    for (const auto elem : c10::irange(map_size)) {
      const int64_t cur_target = target_data[b * map_size + elem];
      if (cur_target == ignore_index) {
        ++num_ignored;
        continue;
      }

      TORCH_CHECK_INDEX(
          cur_target >= 0 && cur_target < n_classes,
          "Target ",
          cur_target,
          " is out of bounds.");

      const auto data = input_data[b * sample_size + cur_target * map_size + elem];
      if (weight_data) {
        const scalar_t weight_val = weight_data[cur_target];
        loss_partial_sums[0] -= data * weight_val;
        weight_partial_sums[0] += weight_val;
      } else {
        loss_partial_sums[0] -= data;
      }

      const int64_t linear_idx = b * map_size + elem;
      for (int64_t j = 0; j + 1 < cascade_sum_num_levels; ++j) {
        const auto mask = (level_mask << (j * level_power));
        if (C10_LIKELY((linear_idx & mask) != 0)) {
          break;
        }

        weight_partial_sums[j + 1] += weight_partial_sums[j];
        loss_partial_sums[j + 1] += loss_partial_sums[j];

        weight_partial_sums[j] = 0;
        loss_partial_sums[j] = 0;
      }
    }
  }

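  // Without a weight tensor every non-ignored element contributes weight 1, so
  // the total weight is simply the count of non-ignored elements; otherwise
  // fold the remaining cascade levels back into a single sum.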
  const scalar_t total_weight_val = !weight_data ?
    static_cast<scalar_t>(numiter - num_ignored) :
    std::accumulate(std::begin(weight_partial_sums),
                    std::end(weight_partial_sums),
                    scalar_t{0});

  scalar_t output_val = std::accumulate(std::begin(loss_partial_sums),
                                        std::end(loss_partial_sums),
                                        scalar_t{0});

  if (reduction == Reduction::Mean) {
    output_val /= total_weight_val;
  }

  *total_weight_data = total_weight_val;
  *output.data_ptr<scalar_t>() = output_val;
}

void nll_loss2d_forward_out_cpu_template(
    Tensor& output,
    Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  check_inputs_nll_loss2d(input, target, weight);
  total_weight.resize_({});

  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16,
      ScalarType::Half,
      input.scalar_type(),
      "nll_loss2d_forward_out_frame",
      [&] {
        nll_loss2d_forward_out_frame<scalar_t>(
            output,
            total_weight,
            input,
            target,
            weight,
            reduction,
            ignore_index);
      });
}

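// Backward: only the slot of the target class receives a gradient,
//   grad_input[b][t][h][w] = -weight[t] * grad_output[b][h][w]
// where t = target[b][h][w]. For the reduced cases grad_output is a single
// scalar, additionally divided by total_weight when reduction == Mean.
// Ignored elements keep their zero gradient.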
template <typename scalar_t>
static void nll_loss2d_backward_out_frame(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);

  if (reduction == at::Reduction::None) {
    check_gradout_shape_nll_loss2d(grad_output, target);

    const int64_t batch_size = input.size(0);
    const int64_t H = input.size(2);
    const int64_t W = input.size(3);

    auto grad_input_acc = grad_input.accessor<scalar_t, 4>();
    auto grad_output_acc = grad_output.accessor<const scalar_t, 3>();
    auto target_acc = target.accessor<const int64_t, 3>();

    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto b : c10::irange(start, end)) {
        for (const auto h : c10::irange(H)) {
          for (const auto w : c10::irange(W)) {
            const int64_t cur_target = target_acc[b][h][w];
            if (cur_target == ignore_index) {
              continue;
            }
            const scalar_t value =
                -(weight_data ? weight_data[cur_target]
                              : static_cast<scalar_t>(1));
            const scalar_t grad_output_value = grad_output_acc[b][h][w];
            grad_input_acc[b][cur_target][h][w] = value * grad_output_value;
          }
        }
      }
    });

    return;
  }

  const scalar_t total_weight_value = *total_weight.const_data_ptr<scalar_t>();

  TORCH_CHECK(
      grad_output.dim() <= 1 && grad_output.numel() == 1,
      "Expected a single element grad_output tensor, but got: ",
      grad_output.sizes());

  const scalar_t grad_output_value = *grad_output.const_data_ptr<scalar_t>();

  const auto target_contiguous = target.contiguous();
  const int64_t* target_data = target_contiguous.const_data_ptr<int64_t>();

  scalar_t* grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  const int64_t batch_size = input.size(0);
  const int64_t n_classes = input.size(1);
  const int64_t map_size = input.size(2) * input.size(3);
  const int64_t sample_size = map_size * n_classes;

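  // For Mean reduction every element's gradient is additionally scaled by
  // 1 / total_weight; for Sum it is just -grad_output.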
  const auto grad = -(reduction == Reduction::Mean ? grad_output_value / total_weight_value
                                                   : grad_output_value);

  at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
    for (const auto b : c10::irange(start, end)) {
      for (const auto elem : c10::irange(map_size)) {
        const int64_t t = target_data[b * map_size + elem];

        if (t != ignore_index) {
          TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");

          const int64_t index = b * sample_size + t * map_size + elem;
          grad_input_data[index] = weight_data != nullptr ? weight_data[t] * grad
                                                          : grad;
        }
      }
    }
  });
}

void nll_loss2d_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  check_inputs_nll_loss2d(input, target, weight);
  grad_input.resize_as_(input);
  grad_input.zero_();
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  TORCH_CHECK(
      total_weight.numel() == 1,
      "expected total_weight to be a single element tensor, got: ",
      total_weight.sizes(),
      " (",
      total_weight.numel(),
      " elements)");

  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16,
      ScalarType::Half,
      input.scalar_type(),
      "nll_loss2d_backward_out_frame",
      [&] {
        nll_loss2d_backward_out_frame<scalar_t>(
            grad_input,
            grad_output,
            input,
            target,
            weight,
            reduction,
            ignore_index,
            total_weight);
      });
}

} // namespace

std::tuple<Tensor&, Tensor&> nll_loss2d_forward_out_cpu(const Tensor& self,
    const Tensor& target, const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    Tensor& output,
    Tensor& total_weight) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  nll_loss2d_forward_out_cpu_template(
      output, total_weight, self, target, weight, reduction, ignore_index);
  return std::tuple<Tensor&, Tensor&>(output, total_weight);
}

std::tuple<Tensor, Tensor> nll_loss2d_forward_cpu(
    const Tensor& self,
    const Tensor& target, const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  auto output = at::empty({0}, self.options());
  auto total_weight = at::empty({0}, self.options());
  at::native::nll_loss2d_forward_out_cpu(
      self, target, weight, reduction, ignore_index, output, total_weight);
  return std::make_tuple(output, total_weight);
}

Tensor& nll_loss2d_backward_out_cpu(const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target, const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight,
    Tensor& grad_input) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  nll_loss2d_backward_out_cpu_template(
      grad_input,
      grad_output,
      self,
      target,
      weight,
      reduction,
      ignore_index,
      total_weight);
  return grad_input;
}

Tensor nll_loss2d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target, const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  auto grad_input = at::zeros_like(self);
  at::native::nll_loss2d_backward_out_cpu(
      grad_output,
      self,
      target,
      weight,
      reduction,
      ignore_index,
      total_weight,
      grad_input);
  return grad_input;
}

Tensor& nll_loss2d_out(const Tensor& self, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, Tensor& output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  Tensor total_weight = at::empty({0}, self.options());
  return std::get<0>(at::nll_loss2d_forward_out(output, total_weight, self, target, weight, reduction, ignore_index));
}

Tensor nll_loss2d_symint(const Tensor& self, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, c10::SymInt ignore_index) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  return std::get<0>(at::nll_loss2d_forward_symint(self, target, weight, reduction, std::move(ignore_index)));
}
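
// A minimal usage sketch (illustrative only; shapes and values are arbitrary):
//   auto input  = at::log_softmax(at::randn({4, 10, 8, 8}), /*dim=*/1);
//   auto target = at::randint(/*high=*/10, {4, 8, 8}, at::kLong);
//   auto loss   = at::nll_loss2d(input, target, /*weight=*/{},
//                                Reduction::Mean, /*ignore_index=*/-100);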

} // namespace at::native