#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/native/Resize.h>
#include <c10/util/SmallBuffer.h>
#include <ATen/TensorSubclassLikeUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cross_entropy_loss_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/log_softmax.h>
#include <ATen/ops/nll_loss.h>
#include <ATen/ops/nll_loss2d.h>
#include <ATen/ops/nll_loss_backward_native.h>
#include <ATen/ops/nll_loss_forward.h>
#include <ATen/ops/nll_loss_forward_native.h>
#include <ATen/ops/nll_loss_native.h>
#include <ATen/ops/nll_loss_nd.h>
#include <ATen/ops/nll_loss_nd_native.h>
#endif

#include <c10/core/TensorOptions.h>
#include <c10/util/irange.h>

#include <utility>

namespace at::meta {
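// Shape-checking / output-allocation metadata for nll_loss_forward.
// Validates that input is 1D or 2D, target is 0D or 1D with a matching batch
// dimension, and that weight (if defined) has one entry per class. Declares
// two outputs: output (per-sample losses of shape {batch_size} when
// reduction == None on 2D input, otherwise a scalar) and total_weight
// (always a scalar).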
TORCH_META_FUNC(nll_loss_forward)
(const Tensor& self,
 const Tensor& target,
 const OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index) {
  const Tensor& weight = weight_opt.getTensorRef();

  TORCH_CHECK(
      self.dim() > 0 && self.dim() <= 2, "input tensor should be 1D or 2D");
  TORCH_CHECK(
      target.dim() <= 1,
      "0D or 1D target tensor expected, multi-target not supported");

  auto no_batch_dim = self.dim() == 1 && target.dim() == 0;
  TORCH_CHECK(
      no_batch_dim || (self.size(0) == target.size(0)),
      "size mismatch (got input: ",
      self.sizes(),
      ", target: ",
      target.sizes(),
      ")")

  const auto n_classes = self.size(-1);

  TORCH_CHECK(
      !weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
      "weight tensor should be defined either for all ",
      n_classes,
      " classes or no classes"
      " but got weight tensor of shape: ",
      weight.sizes());

  const auto n_dims = self.dim();
  const auto batch_size = self.size(0);

  if (reduction == Reduction::None && n_dims == 2) {
    set_output_raw_strided(0, {batch_size}, {}, self.options());
  } else {
    // produce scalar output when reducing or input is 1d
    set_output_raw_strided(0, {}, {}, self.options());
  }

  set_output_raw_strided(1, {}, {}, self.options());
}

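// Shape-checking / output-allocation metadata for nll_loss_backward.
// Re-runs the forward-side input/target/weight checks, verifies that
// total_weight is a single-element tensor, and checks that grad_output has
// the shape implied by the reduction mode ({batch_size} for reduction == None
// on 2D input, a single element otherwise). grad_input is allocated with the
// same shape as self.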
TORCH_META_FUNC(nll_loss_backward)
(const Tensor& grad_output,
 const Tensor& self,
 const Tensor& target,
 OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& total_weight) {
  TORCH_CHECK(
      self.dim() > 0 && self.dim() <= 2, "input tensor should be 1D or 2D");
  TORCH_CHECK(
      target.dim() <= 1,
      "0D or 1D target tensor expected, multi-target not supported");

  auto no_batch_dim = self.dim() == 1 && target.dim() == 0;
  TORCH_CHECK(
      no_batch_dim || (self.size(0) == target.size(0)),
      "size mismatch (got input: ",
      self.sizes(),
      ", target: ",
      target.sizes(),
      ")")
  TORCH_CHECK(
      total_weight.numel() == 1,
      "expected total_weight to be a single element tensor, got: ",
      total_weight.sizes(),
      " (",
      total_weight.numel(),
      " elements)");

  const auto& weight = weight_opt.getTensorRef();

  TORCH_CHECK(
      !weight.defined() || weight.numel() == self.size(-1),
      "weight tensor should be defined either for all or no classes");

  const auto n_dims = self.dim();

  if (reduction == Reduction::None && n_dims == 2) {
    const auto batch_size = self.size(0);
    check_dim_size(grad_output, 1, 0, batch_size);
  } else {
    TORCH_CHECK(
        grad_output.dim() <= 1 && grad_output.numel() == 1,
        "Expected a single element grad_output tensor, but got: ",
        grad_output.sizes());
  }

  set_output_raw_strided(0, self.sizes(), {}, self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT));
}
} // namespace at::meta

namespace at::native {

namespace {

// Returns a contiguous tensor if the source tensor
// is defined. Otherwise returns the undefined
// source tensor unmodified.
inline Tensor optional_contiguous(const Tensor& source) {
  return source.defined() ? source.contiguous() : source;
}

// Returns the address of the first element of a tensor
// or nullptr if the tensor is undefined.
template <typename scalar_t>
inline scalar_t* optional_data(const Tensor& source) {
  if constexpr (std::is_const<scalar_t>::value) {
    return source.defined() ? source.const_data_ptr<scalar_t>() : nullptr;
  } else {
    return source.defined() ? source.data_ptr<scalar_t>() : nullptr;
  }
}

template <typename scalar_t, typename target_t>
static void nll_loss_out_frame(
    const Tensor& output,
    const Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  const auto n_dims = input.dim();
  const auto n_classes = input.size(-1);

  scalar_t* total_weight_data = total_weight.data_ptr<scalar_t>();
  *total_weight_data = 0;

  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);

  if (reduction == Reduction::None && n_dims == 2) {
    const auto batch_size = input.size(0);
    at::native::resize_output(output, {batch_size});

    auto input_acc = input.accessor<const scalar_t, 2>();
    auto target_acc = target.accessor<const target_t, 1>();
    auto output_acc = output.accessor<scalar_t, 1>();

    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto i : c10::irange(start, end)) {
        const auto cur_target = target_acc[i];

        if (cur_target == ignore_index) {
          output_acc[i] = 0;
          continue;
        }

        TORCH_CHECK_INDEX(
            cur_target >= 0 && cur_target < n_classes,
            "Target ",
            cur_target,
            " is out of bounds.");

        scalar_t cur_weight = weight_data != nullptr ? weight_data[cur_target]
                                                     : static_cast<scalar_t>(1);
        output_acc[i] = -input_acc[i][cur_target] * cur_weight;
      }
    });

    return;
  }

  // produce scalar outputs for the reduction case
  at::native::resize_output(output, {});

  if (target.numel() == 0) {
    // Here target (and input) have zero elements
    // Mean reduction on empty tensors produces NaN. See the discussion in
    // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
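    // Illustrative behavior for an empty batch: reduction == Mean fills the
    // scalar output with nan (there is nothing to average over), while
    // reduction == Sum produces 0; total_weight is 0 in both cases.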
    if (reduction == Reduction::Mean) {
      output.fill_(std::numeric_limits<double>::quiet_NaN());
    } else {
      output.zero_();
    }
    total_weight.zero_();
    return;
  }

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();

  const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
  const target_t* target_data = target_contiguous.const_data_ptr<target_t>();

  const int64_t ndim = input.dim();
  const int64_t batch_size = ndim == 1 ? 1 : input.size(0);

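  // Cascade (pairwise-style) summation: per-sample contributions are
  // accumulated into level 0, and every level_step iterations the partial
  // sum at level j is flushed into level j + 1. Compared to a single running
  // sum this bounds the magnitude difference between addends and so reduces
  // floating-point round-off error, which matters for the low-precision
  // types (Half, BFloat16) this kernel is dispatched for.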
  constexpr int64_t cascade_sum_num_levels = 8;
  const int64_t level_power =
      std::max(int64_t(4), utils::CeilLog2(batch_size) / cascade_sum_num_levels);
  const int64_t level_step = (1 << level_power);
  const int64_t level_mask = level_step - 1;

  int64_t num_ignored = 0;

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  scalar_t weight_partial_sums[cascade_sum_num_levels] = {0};
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
  for (const auto b : c10::irange(batch_size)) {
    const int64_t cur_target = target_data[b];
    if (cur_target == ignore_index) {
      ++num_ignored;
      continue;
    }

    TORCH_CHECK_INDEX(
        cur_target >= 0 && cur_target < n_classes,
        "Target ",
        cur_target,
        " is out of bounds.");

    const auto data = input_data[b * n_classes + cur_target];
    if (weight_data) {
      const scalar_t weight_val = weight_data[cur_target];
      loss_partial_sums[0] -= data * weight_val;
      weight_partial_sums[0] += weight_val;
    } else {
      loss_partial_sums[0] -= data;
    }

    for (int64_t j = 0; j + 1 < cascade_sum_num_levels; ++j) {
      const auto mask = (level_mask << (j * level_power));
      if (C10_LIKELY((b & mask) != 0)) {
        break;
      }

      weight_partial_sums[j + 1] += weight_partial_sums[j];
      loss_partial_sums[j + 1] += loss_partial_sums[j];

      weight_partial_sums[j] = 0;
      loss_partial_sums[j] = 0;
    }
  }

  const scalar_t total_weight_val = !weight_data ?
    static_cast<scalar_t>(batch_size - num_ignored) :
    std::accumulate(std::begin(weight_partial_sums),
                    std::end(weight_partial_sums),
                    scalar_t{0});

  scalar_t output_val = std::accumulate(std::begin(loss_partial_sums),
                                        std::end(loss_partial_sums),
                                        scalar_t{0});

  if (reduction == Reduction::Mean) {
    output_val /= total_weight_val;
  }

  // write result to output tensors
  *output.data_ptr<scalar_t>() = output_val;
  *total_weight_data = total_weight_val;
}

void nll_loss_forward_out_cpu_template(
    const Tensor& output,
    const Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16,
      ScalarType::Half,
      input.scalar_type(),
      "nll_loss_out_frame",
      [&] {
        if (target.scalar_type() == kByte) {
          nll_loss_out_frame<scalar_t, uint8_t>(
              output,
              total_weight,
              input,
              target,
              weight,
              reduction,
              ignore_index);
        } else {
          // assumed to be int64
          nll_loss_out_frame<scalar_t, int64_t>(
              output,
              total_weight,
              input,
              target,
              weight,
              reduction,
              ignore_index);
        }
      });
}

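// Backward kernel for nll_loss: for each non-ignored sample i with class
// t = target[i], only grad_input[i][t] is non-zero:
//   grad_input[i][t] = -weight[t] * grad_output[i]              (per-sample)
//   grad_input[i][t] = -weight[t] * grad_output / total_weight  (Mean)
//   grad_input[i][t] = -weight[t] * grad_output                 (Sum)
// with weight[t] == 1 when no weight tensor is given. grad_input is expected
// to be zero-initialized by the caller.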
template <typename scalar_t, typename target_t>
static void nll_loss_backward_out_frame(
    const Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  const auto n_dims = input.dim();
  const auto n_classes = input.size(-1);

  auto target_ = target;
  if (target.dim() == 0) {
    target_ = target.unsqueeze(0);
  }
  auto target_acc = target_.accessor<const target_t, 1>();

  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);

  if (reduction == Reduction::None && n_dims == 2) {
    const auto batch_size = input.size(0);
    auto grad_input_acc = grad_input.accessor<scalar_t, 2>();
    auto grad_output_acc = grad_output.accessor<const scalar_t, 1>();
    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto i : c10::irange(start, end)) {
        auto cur_target = target_acc[i];
        if (cur_target == ignore_index) {
          continue;
        }
        const scalar_t w =
            weight_data ? weight_data[cur_target] : static_cast<scalar_t>(1);
        grad_input_acc[i][cur_target] = -w * grad_output_acc[i];
      }
    });
    return;
  }

  const scalar_t total_weight_value = *total_weight.const_data_ptr<scalar_t>();

  const scalar_t grad_output_value = *grad_output.const_data_ptr<scalar_t>();

  if (input.dim() == 1) {
    auto grad_input_acc = grad_input.accessor<scalar_t, 1>();

    const auto t = target_acc[0];
    if (t != ignore_index) {
      TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");
      const auto grad = -(reduction == Reduction::Mean ? grad_output_value / total_weight_value
                                                       : grad_output_value);
      grad_input_acc[t] = weight_data != nullptr ? weight_data[t] * grad
                                                 : grad;
    }
  } else if (input.dim() == 2) {
    auto grad_input_acc = grad_input.accessor<scalar_t, 2>();
    const auto grad = -(reduction == Reduction::Mean ? grad_output_value / total_weight_value
                                                     : grad_output_value);

    const auto batch_size = input.size(0);

    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto i : c10::irange(start, end)) {
        const auto t = target_acc[i];
        if (t != ignore_index) {
          TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");
          grad_input_acc[i][t] = weight_data != nullptr ? weight_data[t] * grad
                                                        : grad;
        }
      }
    });
  }
}

void nll_loss_backward_out_cpu_template(
    const Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  grad_input.zero_();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16,
      ScalarType::Half,
      input.scalar_type(),
      "nll_loss_backward_out_frame",
      [&] {
        if (target.scalar_type() == kByte) {
          nll_loss_backward_out_frame<scalar_t, uint8_t>(
              grad_input,
              grad_output,
              input,
              target,
              weight,
              reduction,
              ignore_index,
              total_weight);
        } else {
          // assumed to be int64
          nll_loss_backward_out_frame<scalar_t, int64_t>(
              grad_input,
              grad_output,
              input,
              target,
              weight,
              reduction,
              ignore_index,
              total_weight);
        }
      });
}

} // namespace

TORCH_IMPL_FUNC(nll_loss_forward_out_cpu)
(const Tensor& self,
 const Tensor& target,
 const OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& output,
 const Tensor& total_weight) {
  const Tensor& weight = weight_opt.getTensorRef();
  nll_loss_forward_out_cpu_template(
      output, total_weight, self, target, weight, reduction, ignore_index);
}

TORCH_IMPL_FUNC(nll_loss_backward_out_cpu)
(const Tensor& grad_output,
 const Tensor& self,
 const Tensor& target,
 OptionalTensorRef weight_opt,
 int64_t reduction,
 int64_t ignore_index,
 const Tensor& total_weight,
 const Tensor& grad_input
) {
  const Tensor& weight = weight_opt.getTensorRef();
  nll_loss_backward_out_cpu_template(
      grad_input,
      grad_output,
      self,
      target,
      weight,
      reduction,
      ignore_index,
      total_weight);
}

static Tensor cross_entropy_loss_prob_target(
    const Tensor& self,
    const Tensor& target_,
    const Tensor& weight,
    int64_t reduction,
    double label_smoothing) {
  const auto class_dim = self.dim() == 1 ? 0 : 1;
  const auto n_classes = self.size(class_dim);
  TORCH_CHECK(
      !weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
      "cross_entropy: weight tensor should be defined either for all ",
      n_classes,
      " classes or no classes"
      " but got weight tensor of shape: ",
      weight.sizes());

  auto input = at::log_softmax(self, class_dim, self.scalar_type());
  Tensor target;

  if (label_smoothing > 0.0) {
    TORCH_CHECK(label_smoothing <= 1.0, "label_smoothing must be between 0.0 and 1.0. Got: ", label_smoothing);
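    // With label smoothing eps, the soft target distribution becomes
    //   q = (1 - eps) * target + eps / n_classes
    // i.e. the provided class probabilities are mixed with a uniform
    // distribution over the classes.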
    target = target_ * (1 - label_smoothing) + label_smoothing / n_classes;
  } else {
    target = target_;
  }

  if (weight.defined()) {
    // Expand weight to the correct number of dims for broadcasting with input / target
    Tensor weight_ = weight;
    if (input.dim() > 1) {
        auto weight_broadcast_shape = SmallBuffer<int64_t, 5>(input.dim());
        std::fill(weight_broadcast_shape.begin(), weight_broadcast_shape.end(), 1);
        weight_broadcast_shape[1] = weight.size(0);
        weight_ = weight.view(weight_broadcast_shape);
    }

    switch (reduction) {
      case Reduction::Mean:
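        // Mean reduction averages over the number of "samples", i.e.
        // numel / n_classes (batch plus any trailing spatial dimensions),
        // not over the weight sum; an empty input produces nan.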
        if (input.sym_numel() == 0) {
          return -(input * target * weight_).sum().fill_(std::numeric_limits<double>::quiet_NaN());
        } else {
          return -(input * target * weight_).sum() / (input.sym_numel() / n_classes);
        }
      case Reduction::Sum:
        return -(input * target * weight_).sum();
      case Reduction::None:
        return -(input * target * weight_).sum(class_dim);
      default:
        TORCH_CHECK(false, "Invalid reduction type encountered in cross_entropy: ", reduction);
    }
  } else {
    switch (reduction) {
      case Reduction::Mean:
        if (input.sym_numel() == 0) {
          return -(input * target).sum().fill_(std::numeric_limits<double>::quiet_NaN());
        } else {
          return -(input * target).sum() / (input.sym_numel() / n_classes);
        }
      case Reduction::Sum:
        return -(input * target).sum();
      case Reduction::None:
        return -(input * target).sum(class_dim);
      default:
        TORCH_CHECK(false, "Invalid reduction type encountered in cross_entropy: ", reduction);
    }
  }
}

static Tensor cross_entropy_loss_label_smoothing(
    const Tensor& self,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    c10::SymInt ignore_index,
    double label_smoothing) {
    auto class_dim = self.dim() == 1 ? 0 : 1;
    auto input = at::log_softmax(self, class_dim, self.scalar_type());
    auto nllloss = at::nll_loss_nd_symint(input, target, weight, reduction, ignore_index);

    auto n_classes = input.sym_size(class_dim);

    Tensor smooth_loss;
    if (weight.defined()) {
      // Expand weight to the correct number of dims for broadcasting with input / target
      auto weight_broadcast_shape = SmallBuffer<int64_t, 5>(input.dim());
      std::fill(weight_broadcast_shape.begin(), weight_broadcast_shape.end(), 1);
      weight_broadcast_shape[class_dim] = weight.size(0);
      Tensor weight_ = weight.view(weight_broadcast_shape);

      smooth_loss = -(input * weight_).sum(class_dim);
    } else {
      smooth_loss = -input.sum(class_dim);
    }

    auto ignore_mask = target == std::move(ignore_index);
    smooth_loss.masked_fill_(ignore_mask, 0.0);

    Tensor ret;
    switch (reduction) {
      case Reduction::Mean:
        if (weight.defined()) {
          if (isTensorSubclassLike(weight)) {
            // we will collect weights from 0 index which is always valid
            // and mask them out if they are ignored
            auto filtered_target = target.masked_fill(ignore_mask, 0);
            auto tgt_weights = weight.gather(0, filtered_target.flatten());
            auto weight_sum =
                tgt_weights.masked_fill_(ignore_mask.flatten(), 0).sum();
            ret = smooth_loss.sum() / weight_sum;
          } else {
            // TODO: This code path can be removed if #61309 is resolved
            // loss is normalized by the weights to be consistent with
            // nll_loss_nd
            ret = smooth_loss.sum() /
                weight.gather(0, target.masked_select(~ignore_mask).flatten())
                    .sum();
          }
        } else {
          auto true_mask = ~ignore_mask;
          ret = smooth_loss.sum() / true_mask.sum();
        }
        break;
      case Reduction::Sum:
        ret = smooth_loss.sum();
        break;
      case Reduction::None:
        ret = smooth_loss;
        break;
      default:
        TORCH_CHECK(false, "Invalid reduction type encountered in cross_entropy: ", reduction);
    }
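    // Combine the two terms of the label-smoothed loss:
    //   loss = (1 - eps) * NLL(log_softmax(x), target)
    //        + eps * (1 / n_classes) * sum_c -w_c * log_softmax(x)_c
    // where ret holds the (reduced) per-sample sum over classes of the
    // weighted negative log-probabilities computed above.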
    return (1 - label_smoothing) * nllloss + ret * (label_smoothing / n_classes);
}

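// cross_entropy_loss is implemented as a composite of log_softmax and an
// NLL-style loss. Roughly (Python-level, illustrative):
//   F.cross_entropy(x, t) == F.nll_loss(F.log_softmax(x, dim=1), t)
// for hard (class-index) targets without label smoothing; soft (probability)
// targets and label smoothing take the dedicated paths above.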
Tensor cross_entropy_loss_symint(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight,
    int64_t reduction,
    c10::SymInt ignore_index,
    double label_smoothing) {
  Tensor ret;
  if (self.sym_sizes() == target.sym_sizes()) {
    // Assume soft targets when input and target shapes are the same
    TORCH_CHECK(at::isFloatingType(target.scalar_type()),
        "Expected floating point type for target with class probabilities, got ", target.scalar_type());
    TORCH_CHECK(ignore_index < 0, "ignore_index is not supported for floating point target");

    // See [Note: hacky wrapper removal for optional tensor]
    c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight);
    const Tensor& weight_ = *weight_maybe_owned;
    ret = cross_entropy_loss_prob_target(self, target, weight_, reduction, label_smoothing);
  } else if (label_smoothing > 0.0) {
    TORCH_CHECK(label_smoothing <= 1.0, "label_smoothing must be between 0.0 and 1.0. Got: ", label_smoothing);

    // See [Note: hacky wrapper removal for optional tensor]
    c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight);
    const Tensor& weight_ = *weight_maybe_owned;
    ret = cross_entropy_loss_label_smoothing(self, target, weight_, reduction, std::move(ignore_index), label_smoothing);
  } else {
    auto class_dim = self.dim() == 1 ? 0 : 1;
    ret = at::nll_loss_nd_symint(
        at::log_softmax(self, class_dim, self.scalar_type()),
        target,
        weight,
        reduction,
        std::move(ignore_index));
  }
  return ret;
}

Tensor & nll_loss_out(const Tensor & self, const Tensor & target, const std::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, Tensor & output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  Tensor total_weight = at::empty({0}, self.options());
  return std::get<0>(at::nll_loss_forward_out(output, total_weight, self, target, weight, reduction, ignore_index));
}

Tensor nll_loss_symint(const Tensor & self, const Tensor & target, const std::optional<Tensor>& weight_opt, int64_t reduction, c10::SymInt ignore_index) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, std::move(ignore_index)));
}

Tensor nll_loss_nd_symint(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight,
    int64_t reduction,
    c10::SymInt ignore_index) {
  if (self.dim() < 1) {
    TORCH_CHECK_VALUE(
        false, "Expected 1 or more dimensions (got ", self.dim(), ")");
  }

  if (self.dim() != 1 && self.sym_sizes()[0] != target.sym_sizes()[0]) {
    TORCH_CHECK_VALUE(
        false,
        "Expected input batch_size (",
        self.sym_sizes()[0],
        ") to match target batch_size (",
        target.sym_sizes()[0],
        ").");
  }

  Tensor ret;
  Tensor input_ = self;
  Tensor target_ = target;
  if (input_.dim() == 1 || input_.dim() == 2) {
    ret = at::nll_loss_symint(input_, target_, weight, reduction, std::move(ignore_index));
  } else if (input_.dim() == 4) {
    ret = at::nll_loss2d_symint(input_, target_, weight, reduction, std::move(ignore_index));
  } else {
    // dim == 3 or dim > 4
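    // Inputs of shape (N, C, d1, ..., dk) that nll_loss2d cannot handle
    // directly are flattened to (N, C, 1, d1*...*dk) (the target to
    // (N, 1, d1*...*dk)) and routed through nll_loss2d; with
    // reduction == None the per-element result is reshaped back to
    // (N, d1, ..., dk).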
    auto n = input_.sym_sizes()[0];
    auto c = input_.sym_sizes()[1];
    auto out_size = input_.sym_sizes().slice(2).vec();
    out_size.insert(out_size.begin(), n);
    if (target_.sym_sizes().slice(1) != input_.sym_sizes().slice(2)) {
      TORCH_CHECK(
          false,
          "Expected target size ",
          SymIntArrayRef(out_size),
          ", got ",
          target_.sym_sizes());
    }
    input_ = input_.contiguous();
    target_ = target_.contiguous();
    // support empty batches, see #15870
    if (input_.sym_numel() > 0) {
      input_ = input_.view_symint({n, std::move(c), 1, -1});
    } else {
      input_ = input_.view_symint({n, std::move(c), 0, 0});
    }
    if (target_.sym_numel() > 0) {
      target_ = target_.view_symint({std::move(n), 1, -1});
    } else {
      target_ = target_.view_symint({std::move(n), 0, 0});
    }
    if (reduction != Reduction::None) {
      ret = at::nll_loss2d_symint(input_, target_, weight, reduction, std::move(ignore_index));
    } else {
      auto out =
          at::nll_loss2d_symint(input_, target_, weight, reduction, std::move(ignore_index));
      ret = out.view_symint(out_size);
    }
  }
  return ret;
}

} // namespace at::native